feat(ai-security-guard): structured x_higress deny response, error-path metrics, and AI logging (#3894)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
JianweiWang
2026-05-29 10:45:10 +08:00
committed by GitHub
parent 385f8d8b4e
commit c21a38e783
14 changed files with 2181 additions and 195 deletions


@@ -34,6 +34,7 @@ description: 阿里云内容安全检测
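| `denyCode` | int | optional | 200 | Response status code returned when content is judged illegal |
| `denyMessage` | string | optional | Streaming/non-streaming response in OpenAI format | Response content returned when content is judged illegal |
| `protocol` | string | optional | openai | Protocol format; set to `original` for non-OpenAI protocols |
| `openAIDenyResponseFormat` | string | optional | legacy | Shape of the OpenAI-wrapped deny response, either `legacy` or `structured`. The default `legacy` preserves historical compatibility; with `structured`, structured blocking details are emitted at `choices[0].x_higress_guardrail` |
| `contentModerationLevelBar` | string | optional | max | Blocking risk level for content moderation detection, one of `max`, `high`, `medium` or `low` |
| `promptAttackLevelBar` | string | optional | max | Blocking risk level for prompt attack detection, one of `max`, `high`, `medium` or `low` |
| `sensitiveDataLevelBar` | string | optional | S4 | Blocking risk level for sensitive data detection, one of `S4`, `S3`, `S2` or `S1` |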
@@ -47,19 +48,18 @@ description: 阿里云内容安全检测
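### Deny Response Structure
When content is blocked, the plugin (`MultiModalGuard` action) uniformly returns the following structured JSON object. Where it is carried depends on the protocol:
When content is blocked, the plugin (`MultiModalGuard` action) builds the following structured JSON object. The `protocol: original`, MCP, and image-generation paths return this object directly or indirectly. The OpenAI text-generation wrapping path keeps the historical response shape by default and embeds the object in the OpenAI response only when `openAIDenyResponseFormat: structured` is configured.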
```json
{
"code": 200,
"denyMessage": "很抱歉,我无法回答您的问题",
"blockedDetails": [
{
"Type": "contentModeration",
"Level": "high",
"Suggestion": "block"
"type": "contentModeration",
"level": "high"
}
],
"requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****",
"guardCode": 200
]
}
```
@@ -67,22 +67,26 @@ description: 阿里云内容安全检测
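| Field | Type | Description |
| --- | --- | --- |
| `blockedDetails` | array | Details of the dimensions that triggered blocking. Synthesised automatically from top-level risk signals when the security service returns no details |
| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` |
| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low` |
| `blockedDetails[].Suggestion` | string | Action recommended by the security service, usually `block` |
| `requestId` | string | Request ID from the security service, for tracing |
| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` when the check succeeded) |
| `code` | int | On the `text_generation` (OpenAI wrapping) and `image_generation` paths, the HTTP status code returned by the gateway, taken from `denyCode` (default `200`). On the `protocol=original` and `mcp` paths, the business code returned by the security service (`Response.Code`; `200` when the check succeeded) |
| `denyMessage` | string | Human-readable deny text. Always present on the OpenAI wrapping path, taken from `denyMessage` (default `很抱歉,我无法回答您的问题`). On the `protocol=original` / `image_generation` / `mcp` paths it is taken from `denyMessage` and omitted when unconfigured (`omitempty`) |
| `blockedDetails` | array | Details of the dimensions that triggered blocking. Synthesised automatically from the top-level `RiskLevel`/`AttackLevel` when the security service returns no `Detail`. Returns `[]` when no dimension is hit |
| `blockedDetails[].type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` / `customLabel` |
| `blockedDetails[].level` | string | Risk level: `high` / `medium` / `low`; for sensitive data, `S1` to `S4` |
> Note: the deny body produced by the current implementation contains only the fields above. It does not emit the security service's `RequestId`, the per-entry `Suggestion`, or the raw business code (`guardCode`). The security service's `RequestId` is exposed through the AI log field `safecheck_request_ids` (see the AI Log section below).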
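Where the body is carried, per protocol:
- **`text_generation` (OpenAI non-streaming)**: the structure above is serialised as a JSON string and placed in `choices[0].message.content`
- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk
- **`text_generation` (`protocol=original`)**: the structure above is returned directly as the JSON response body
- **`text_generation` (OpenAI, default `legacy`)**: emits neither `x_higress_guardrail` nor the historical `x_higress` field; `choices[0].message.content` / the first-frame `delta.content` keeps the historical content shape (a JSON string for RiskBlock, the deny text for mask fallback), `finish_reason` is `"stop"`, and streaming responses still end with `data: [DONE]`
- **`text_generation` (OpenAI, `structured` non-streaming)**: `choices[0].message.content` carries the human-readable deny text (that is, `denyMessage`, which defaults to `很抱歉,我无法回答您的问题` when unconfigured); the structure above is placed at `choices[0].x_higress_guardrail` as an embedded object (not a JSON string)
- **`text_generation` (OpenAI, `structured` streaming SSE)**: the first frame's `delta.content` carries the human-readable deny text; the structure above is attached only to the last chunk, at `choices[0].x_higress_guardrail`, as an embedded object, after which the stream ends with `data: [DONE]`
- **`text_generation` (`protocol=original`)**: the structure above is returned directly as the JSON response body (no OpenAI wrapper, no added `x_higress_guardrail`)
- **`image_generation`**: the structure above is returned directly as the JSON response body (HTTP 403)
- **`mcp` (JSON-RPC)**: the structure above is serialised as a JSON string and placed in `error.message`
- **`mcp` (SSE)**: same, returned via an SSE event
`openAIDenyResponseFormat` affects only the body shape of OpenAI-wrapped deny responses; blocking decisions, fail-open behavior, metrics, and AI Log fields do not change with it. The field can be configured only at plugin global scope and cannot be placed under `consumerRiskLevel`.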
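Additional notes on the four levels of the three risk types (content moderation detection, prompt attack detection, and sensitive data detection):
- For content moderation detection and prompt attack detection: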
@@ -150,6 +154,14 @@ checkRequest: true
checkResponse: true
```
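### Configure OpenAI Structured Deny Responses
The default `openAIDenyResponseFormat: legacy` keeps the historical response shape. To emit structured blocking details in OpenAI responses, configure: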
```yaml
openAIDenyResponseFormat: structured
```
### Use Temporary Security Credentials
```yaml
@@ -247,11 +259,61 @@ ai-security-guard 插件提供了以下监控指标:
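- `ai_sec_request_deny`: number of requests whose request content failed the security check
- `ai_sec_response_deny`: number of requests whose model answer failed the security check
#### Image response-phase metric / ai_log field rename (transition period)
Historically, the image generation handlers (`lvwang/multi_modal_guard/image/openai.go` and `lvwang/multi_modal_guard/image/qwen.go`) wrongly wrote request-phase fields when a risk was hit in the **response phase**. This release corrects the semantics and keeps a **double-write transition** for 1 to 2 release cycles:
| Behavior | Legacy value (wrong; to be removed in a later release) | New value (recommended) |
| --- | --- | --- |
| Counter (deny) | `ai_sec_request_deny` | `ai_sec_response_deny` |
| ai_log latency (pass + deny) | `safecheck_request_rt` | `safecheck_response_rt` |
| ai_log status (deny) | `safecheck_status="reqeust deny"` (a typo; **dropped in this release and no longer written**) | `safecheck_status="response deny"` |
During the transition period, the image response phase writes both the new and the legacy `*_deny` counters and `safecheck_*_rt` fields; `safecheck_status` writes only the new value. Switch dashboards and alerts to the `response_*` field names as soon as possible. Any image-response alert that still depends on the misspelled `reqeust deny` status string must be changed to `response deny` immediately.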
### Trace
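If tracing is enabled, the ai-security-guard plugin adds the following attributes to the request span:
- `ai_sec_risklabel`: the risk type hit by the request
- `ai_sec_deny_phase`: the phase in which the risk was detected, either `request` or `response`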
### AI Log
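The ai-security-guard plugin writes the result of every check submitted to the content security service into the AI access log, so that gateway logs can be correlated with Alibaba Cloud content security requests:
| Field | Type | Description |
| --- | --- | --- |
| `safecheck_requests` | array | Array of check submission events. Each item is `{"requestId"?: string, "phase": string, "modality": string, "result": string}` |
| `safecheck_request_ids` | array | All valid content security `RequestId` values within the current gateway request, kept in submission completion order, without deduplication or truncation |
| `safecheck_request_id` | string | The latest valid content security `RequestId`, kept for log consumers that read only a single value |
| `safecheck_status` | string | Legacy compatibility field reflecting the last status transition of this gateway request (see the enum below) |
| `safecheck_request_rt` / `safecheck_response_rt` | int | Latency in milliseconds of the security check in the request / response phase |
| `safecheck_riskLabel` / `safecheck_riskWords` | string | Risk label and risk words when a risk is hit (taken from the first hit result returned by the security service) |
`safecheck_requests[].phase` is `request` or `response`; `modality` is `text`, `image`, or `mcp`; `result` describes **the processing outcome of that submission event itself** (not the gateway's final outbound action). Values and meanings:
| `result` value | Meaning |
| --- | --- |
| `pass` | The submission passed the check |
| `deny` | The submission hit a risk and the gateway returned a deny response |
| `mask` | The submission hit a risk with `Action=Mask`; the security service returned desensitized text, which was used to rewrite the request body |
| `error` | The submission itself failed (HTTP non-200, business Code non-200, unmarshal failure, failure to build the deny response, failure to call the content security service, and so on). When the error occurs in the **response-phase streaming callback** because building the deny response failed, the gateway fails open and passes the buffered upstream content through as-is; in that case `safecheck_status=build_fallback_pass`, and the corresponding event has `result=error` to indicate that this security submission did not complete |
`requestId`, `safecheck_request_ids`, and `safecheck_request_id` are written only when the `RequestId` in the security service response is a JSON string and `strings.TrimSpace(RequestId) != ""`; missing, empty, whitespace-only, or non-string values do not produce empty placeholders.
Every submission attempt generates one `safecheck_requests` event, including error cases such as HTTP non-200, business failure codes, and failure to call the content security service; these are recorded as `result=error`. Prefer `safecheck_requests` when you need precise auditing of multiple submissions, streaming segments, or repeated image checks.
`safecheck_status` enum (legacy field, overwritten on each status transition, so only the last transition's meaning is kept when there are multiple submissions):
| `safecheck_status` value | Meaning |
| --- | --- |
| `request pass` | All request-phase submissions passed |
| `request mask` | A request-phase submission hit mask; the request body was rewritten with desensitized text |
| `reqeust deny` | A request-phase submission hit a risk; the gateway returned a deny response (note: the spelling `reqeust` is historical and kept for backward compatibility) |
| `request error` | A request-phase security submission itself failed (HTTP / unmarshal / calling the security service, and so on); the gateway fails open |
| `response pass` | All response-phase submissions passed |
| `response deny` | A response-phase submission hit a risk; the gateway returned a deny response |
| `response error` | A response-phase security submission itself failed; the gateway fails open |
| `build_fallback_pass` | Building the deny response failed in the response-phase streaming callback; the gateway fails open and passes the buffered upstream content through as-is |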
## Request Example
```bash
curl http://localhost/v1/chat/completions \
@@ -267,25 +329,39 @@ curl http://localhost/v1/chat/completions \
}'
```
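The request content is sent to the Alibaba Cloud content security service for checking. If the request content is judged illegal, the gateway returns a response like the following:
When `openAIDenyResponseFormat: structured` is configured, the request content is sent to the Alibaba Cloud content security service for checking. If the request content is judged illegal, the gateway returns a response like the following: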
```json
{
"id": "chatcmpl-AAy3hK1dE4ODaegbGOMoC9VY4Sizv",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o-mini",
"system_fingerprint": "fp_44709d6fcb",
"created": 1727078400,
"model": "from-security-guard",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。"
},
"logprobs": null,
"finish_reason": "stop"
"finish_reason": "stop",
"x_higress_guardrail": {
"code": 200,
"denyMessage": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
"blockedDetails": [
{
"type": "contentModeration",
"level": "high"
}
]
}
}
]
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
```
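A client that receives a non-streaming response of this shape could read the embedded object roughly as follows. This is a minimal sketch, not part of the plugin: the type names `Guardrail` and `BlockedDetail` and the function `ExtractGuardrail` are hypothetical, and the struct fields simply mirror the documented `code` / `denyMessage` / `blockedDetails` keys.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// BlockedDetail mirrors one documented blockedDetails entry (hypothetical type name).
type BlockedDetail struct {
	Type  string `json:"type"`
	Level string `json:"level"`
}

// Guardrail mirrors the documented deny body (hypothetical type name).
type Guardrail struct {
	Code           int             `json:"code"`
	DenyMessage    string          `json:"denyMessage,omitempty"`
	BlockedDetails []BlockedDetail `json:"blockedDetails"`
}

// chatCompletion holds only the parts of the OpenAI response this sketch reads.
type chatCompletion struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		Guardrail *Guardrail `json:"x_higress_guardrail,omitempty"`
	} `json:"choices"`
}

// ExtractGuardrail returns the guardrail object embedded in the first choice.
// It returns nil when the key is absent, which is the case for a normal answer
// or for the default legacy format.
func ExtractGuardrail(body []byte) (*Guardrail, error) {
	var resp chatCompletion
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	if len(resp.Choices) == 0 {
		return nil, nil
	}
	return resp.Choices[0].Guardrail, nil
}

func main() {
	// Sample body assembled from the documented fields; values are illustrative.
	body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"denied"},
		"finish_reason":"stop","x_higress_guardrail":{"code":200,"denyMessage":"denied",
		"blockedDetails":[{"type":"contentModeration","level":"high"}]}}]}`)
	g, err := ExtractGuardrail(body)
	if err != nil {
		panic(err)
	}
	if g != nil {
		for _, d := range g.BlockedDetails {
			fmt.Printf("blocked: type=%s level=%s\n", d.Type, d.Level)
		}
	}
}
```

Using a pointer for the embedded object lets the caller tell "no guardrail key" apart from an object with an empty `blockedDetails` array.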


@@ -34,6 +34,7 @@ Plugin Priority: `300`
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
| `denyMessage` | string | optional | Streaming/non-streaming response in OpenAI format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal |
| `protocol` | string | optional | openai | protocol format, `openai` or `original` |
| `openAIDenyResponseFormat` | string | optional | legacy | OpenAI-wrapped deny response format, `legacy` or `structured`. The default `legacy` preserves historical compatibility; `structured` embeds blocking details at `choices[0].x_higress_guardrail` |
| `contentModerationLevelBar` | string | optional | max | contentModeration risk level threshold, `max`, `high`, `medium` or `low` |
| `promptAttackLevelBar` | string | optional | max | promptAttack risk level threshold, `max`, `high`, `medium` or `low` |
| `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` |
@@ -71,19 +72,18 @@ Risk level explanations for each detection dimension:
### Deny Response Body
When content is blocked, the plugin (`MultiModalGuard` action) returns the following structured JSON object. The location in the response depends on the protocol:
When content is blocked, the plugin (`MultiModalGuard` action) builds the following structured JSON object. `protocol: original`, MCP, and image-generation paths return it directly or indirectly; OpenAI text-generation wrapping keeps the historical response shape by default, and embeds this object only when `openAIDenyResponseFormat: structured` is configured.
```json
{
"code": 200,
"denyMessage": "Sorry, I cannot answer your question.",
"blockedDetails": [
{
"Type": "contentModeration",
"Level": "high",
"Suggestion": "block"
"type": "contentModeration",
"level": "high"
}
],
"requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****",
"guardCode": 200
]
}
```
@@ -91,22 +91,26 @@ Field descriptions:
| Field | Type | Description |
| --- | --- | --- |
| `blockedDetails` | array | Details of the triggered blocking dimensions. Synthesised from top-level risk signals when the security service returns no detail entries. |
| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` |
| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low` etc. |
| `blockedDetails[].Suggestion` | string | Action recommended by the security service, usually `block` |
| `requestId` | string | Request ID from the security service, for tracing |
| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` indicates a successful check that detected a risk) |
| `code` | int | For `text_generation` (OpenAI wrapping) and `image_generation` paths, this is the HTTP status the gateway returns, sourced from `denyCode` (default `200`). For `protocol=original` and `mcp` paths, this is the business code returned by the security service (`Response.Code`; `200` indicates a successful check that detected a risk). |
| `denyMessage` | string | Human-readable deny text. Always present on OpenAI-wrapping paths, taken from `denyMessage` (defaults to `Sorry, I cannot answer your question.`). On `protocol=original` / `image_generation` / `mcp` paths the value is taken from `denyMessage` and omitted (`omitempty`) when unconfigured. |
| `blockedDetails` | array | Details of the triggered blocking dimensions. Synthesised from top-level `RiskLevel`/`AttackLevel` when the security service returns no `Detail` entries. Returns `[]` when no dimension is hit. |
| `blockedDetails[].type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` / `customLabel` |
| `blockedDetails[].level` | string | Risk level: `high` / `medium` / `low`; for sensitive data: `S1` to `S4` |
> Note: the current implementation emits only the fields above. The security service's `RequestId`, per-detail `Suggestion`, and raw business code (`guardCode`) are not embedded in the deny body. The security service's `RequestId` is exposed via the AI access log field `safecheck_request_ids` (see the AI Log section below).
How the body is embedded per protocol:
- **`text_generation` (OpenAI non-streaming)**: serialised as a JSON string and placed in `choices[0].message.content`
- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk
- **`text_generation` (`protocol=original`)**: returned directly as the JSON response body
- **`text_generation` (OpenAI, default `legacy`)**: emits neither `x_higress_guardrail` nor the historical `x_higress` field; `choices[0].message.content` / the first `delta.content` frame keeps the historical content shape (a JSON string for RiskBlock, deny text for mask fallback), `finish_reason` is `"stop"`, and streaming responses still end with `data: [DONE]`
- **`text_generation` (OpenAI, `structured` non-streaming)**: `choices[0].message.content` carries the human-readable deny text (`denyMessage`, defaults to `Sorry, I cannot answer your question.` when unconfigured); the structure above is placed at `choices[0].x_higress_guardrail` as an embedded object (not a JSON string)
- **`text_generation` (OpenAI, `structured` streaming SSE)**: the first frame's `delta.content` carries the human-readable deny text; the structure above is attached only to the last chunk at `choices[0].x_higress_guardrail` as an embedded object, followed by `data: [DONE]`
- **`text_generation` (`protocol=original`)**: returned directly as the JSON response body (no OpenAI wrapper, no `x_higress_guardrail`)
- **`image_generation`**: returned directly as the JSON response body (HTTP 403)
- **`mcp` (JSON-RPC)**: serialised as a JSON string and placed in `error.message`
- **`mcp` (SSE)**: same, returned via SSE event
`openAIDenyResponseFormat` only changes the OpenAI-wrapped deny body shape; blocking decisions, fail-open behavior, metrics, and AI Log fields do not vary by format. Configure this field only at plugin global scope, not under `consumerRiskLevel`.
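For the `structured` streaming case described in the list above, a client could collect the deny text and pick up the embedded object from whichever chunk carries it. The following is a minimal sketch under stated assumptions: `ScanDenyStream` and `sseChunk` are hypothetical names, the stream is assumed to use one `data:` line per chunk, and each line is assumed to fit the default `bufio.Scanner` buffer.

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// sseChunk holds only the parts of a streaming chunk this sketch reads.
type sseChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		Guardrail json.RawMessage `json:"x_higress_guardrail"`
	} `json:"choices"`
}

// ScanDenyStream walks an SSE payload and returns the concatenated
// delta.content text plus the raw x_higress_guardrail object, if any chunk
// carried one. The guardrail is nil for legacy responses and normal answers.
func ScanDenyStream(stream string) (string, json.RawMessage, error) {
	var text strings.Builder
	var guardrail json.RawMessage
	sc := bufio.NewScanner(strings.NewReader(stream))
	for sc.Scan() {
		line := strings.TrimSpace(sc.Text())
		if !strings.HasPrefix(line, "data:") {
			continue // blank separators and non-data fields
		}
		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if payload == "" || payload == "[DONE]" {
			continue
		}
		var chunk sseChunk
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			return "", nil, err
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		text.WriteString(chunk.Choices[0].Delta.Content)
		if g := chunk.Choices[0].Guardrail; len(g) > 0 && string(g) != "null" {
			guardrail = g
		}
	}
	return text.String(), guardrail, sc.Err()
}

func main() {
	// Illustrative stream built from the documented shape; not captured output.
	stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Sorry\"}}]}\n\n" +
		"data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"," +
		"\"x_higress_guardrail\":{\"code\":200,\"blockedDetails\":[]}}]}\n\n" +
		"data: [DONE]\n\n"
	text, guardrail, err := ScanDenyStream(stream)
	if err != nil {
		panic(err)
	}
	fmt.Println(text, string(guardrail))
}
```

Keeping the guardrail as `json.RawMessage` avoids committing the client to a fixed schema while the field set is still evolving.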
## Examples of configuration
### Check if the input is legal
@@ -131,6 +135,14 @@ checkRequest: true
checkResponse: true
```
### Configure OpenAI Structured Deny Responses
The default `openAIDenyResponseFormat: legacy` keeps the historical response shape. To emit structured blocking details in OpenAI responses, configure:
```yaml
openAIDenyResponseFormat: structured
```
### Configure response fallback extraction paths
When primary extraction paths are empty, you can configure ordered fallback paths to support multiple response formats:
@@ -165,7 +177,57 @@ ai-security-guard plugin provides following metrics:
- `ai_sec_request_deny`: count of requests denied at request phase
- `ai_sec_response_deny`: count of requests denied at response phase
#### Image response-phase metric / ai_log rename (transition window)
The image generation handlers (`lvwang/multi_modal_guard/image/openai.go` and `lvwang/multi_modal_guard/image/qwen.go`) historically emitted request-phase field names for **response-phase** events. This release corrects the semantics and keeps a **double-write transition** for 1 to 2 release cycles:
| Signal | Legacy value (wrong; removed in a future release) | New value (recommended) |
| --- | --- | --- |
| Counter (deny) | `ai_sec_request_deny` | `ai_sec_response_deny` |
| ai_log latency (pass + deny) | `safecheck_request_rt` | `safecheck_response_rt` |
| ai_log status (deny) | `safecheck_status="reqeust deny"` (typo; **dropped immediately, no longer emitted**) | `safecheck_status="response deny"` |
During the transition window, the image response phase emits both the new and the legacy `*_deny` counters and `safecheck_*_rt` attributes; `safecheck_status` only emits the new value. Migrate dashboards / alerts to the `response_*` names; any image-response alert that still keys off the typo'd `reqeust deny` status string must move to `response deny` immediately.
### Trace
The ai-security-guard plugin provides the following span attributes:
- `ai_sec_risklabel`: risk type of this request
- `ai_sec_deny_phase`: denied phase of this request, value can be request/response
- `ai_sec_deny_phase`: denied phase of this request, value can be request/response
### AI Log
ai-security-guard writes each submission to the content security service into the AI access log, so gateway logs can be correlated with Alibaba Cloud content security requests:
| Field | Type | Description |
| --- | --- | --- |
| `safecheck_requests` | array | Submission event array. Each item is `{"requestId"?: string, "phase": string, "modality": string, "result": string}` |
| `safecheck_request_ids` | array | All valid content security `RequestId` values for the current gateway request, preserved in submission completion order without deduplication or truncation |
| `safecheck_request_id` | string | The latest valid content security `RequestId`, kept for consumers that only read a single value |
| `safecheck_status` | string | Legacy compatibility field reflecting the last status transition for this gateway request (see enum below) |
| `safecheck_request_rt` / `safecheck_response_rt` | int | Latency (ms) of the security check during the request / response phase |
| `safecheck_riskLabel` / `safecheck_riskWords` | string | Risk label and risk words when a risk is hit (taken from the first result returned by the security service) |
`safecheck_requests[].phase` is `request` or `response`; `modality` is `text`, `image`, or `mcp`; `result` describes **the processing outcome of that submission event itself** (not the gateway's final outbound action). Values:
| `result` value | Meaning |
| --- | --- |
| `pass` | The submission passed the check |
| `deny` | The submission hit a risk; the gateway returned a deny response |
| `mask` | The submission hit a risk with `Action=Mask`; the security service returned desensitized text and the request body was rewritten |
| `error` | The submission itself failed (HTTP non-200, business `Code` non-200, unmarshal failure, deny-response build failure, dispatch failure, etc.). When the failure occurs in the **streaming response callback** because building the deny response failed, the gateway fails open (injects buffered upstream content as-is); in that case `safecheck_status=build_fallback_pass` and the corresponding event has `result=error` to indicate the security submission did not complete |
The plugin writes `requestId`, `safecheck_request_ids`, and `safecheck_request_id` only when the security service response contains a JSON string `RequestId` and `strings.TrimSpace(RequestId) != ""`; missing, empty, whitespace-only, or non-string values do not produce empty placeholders.
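The validity rule in the previous paragraph can be sketched with the standard library alone. This is an illustration of the stated rule, not the plugin's code: `ValidRequestID` is a hypothetical helper, and returning the untrimmed value is an assumption, since the text does not say whether the stored ID is trimmed.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// ValidRequestID reports whether the security service response carries a
// RequestId that is a JSON string and non-blank after trimming whitespace.
// Missing, null, empty, whitespace-only, and non-string values are rejected.
func ValidRequestID(body []byte) (string, bool) {
	var envelope struct {
		RequestID json.RawMessage `json:"RequestId"`
	}
	if err := json.Unmarshal(body, &envelope); err != nil {
		return "", false
	}
	var id string
	// Fails for a missing key (empty input) and for non-string values such as 123.
	if err := json.Unmarshal(envelope.RequestID, &id); err != nil {
		return "", false
	}
	if strings.TrimSpace(id) == "" {
		return "", false
	}
	return id, true // assumption: the value is kept as received
}

func main() {
	for _, body := range []string{
		`{"Code":200,"RequestId":"req-1"}`,
		`{"Code":200}`,
		`{"Code":200,"RequestId":"  "}`,
		`{"Code":200,"RequestId":123}`,
	} {
		id, ok := ValidRequestID([]byte(body))
		fmt.Printf("%-34s ok=%v id=%q\n", body, ok, id)
	}
}
```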
Every submission attempt emits one `safecheck_requests` event, including HTTP non-200 responses, business failures, and failures to dispatch the security service call. These error paths are recorded as `result=error`. Use `safecheck_requests` for precise auditing across multiple submissions, streaming chunks, or multiple image checks.
`safecheck_status` enum (legacy field; overwritten on each status transition, so only the last transition's value is preserved when there are multiple submissions):
| `safecheck_status` value | Meaning |
| --- | --- |
| `request pass` | All request-phase submissions passed |
| `request mask` | A request-phase submission hit mask; the request body was rewritten with desensitized text |
| `reqeust deny` | A request-phase submission hit a risk; the gateway returned a deny response (note: typo `reqeust` is preserved for backward compatibility) |
| `request error` | A request-phase security submission itself failed (HTTP / unmarshal / dispatch / etc.); the gateway fails open |
| `response pass` | All response-phase submissions passed |
| `response deny` | A response-phase submission hit a risk; the gateway returned a deny response |
| `response error` | A response-phase security submission itself failed; the gateway fails open |
| `build_fallback_pass` | In the streaming response callback, building the deny response failed; the gateway fails open and injects the buffered upstream content as-is |
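Because `safecheck_status` keeps only the last transition, a log consumer that wants per-submission counts has to read `safecheck_requests` instead. The sketch below tallies events by phase and result. It assumes the AI log entry is available as a flat JSON object with the documented keys; `SubmissionEvent`, `aiLogEntry`, and `TallyResults` are hypothetical names, and the sample entry is made up for illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// SubmissionEvent mirrors one documented safecheck_requests item.
type SubmissionEvent struct {
	RequestID string `json:"requestId,omitempty"`
	Phase     string `json:"phase"`
	Modality  string `json:"modality"`
	Result    string `json:"result"`
}

// aiLogEntry holds only the fields this sketch reads from one log entry.
type aiLogEntry struct {
	Requests []SubmissionEvent `json:"safecheck_requests"`
	Status   string            `json:"safecheck_status"`
}

// TallyResults counts submission events per "phase/result" key and also
// returns the legacy safecheck_status so the two views can be compared.
func TallyResults(entry []byte) (map[string]int, string, error) {
	var e aiLogEntry
	if err := json.Unmarshal(entry, &e); err != nil {
		return nil, "", err
	}
	counts := make(map[string]int)
	for _, ev := range e.Requests {
		counts[ev.Phase+"/"+ev.Result]++
	}
	return counts, e.Status, nil
}

func main() {
	// Hypothetical entry: one failed image submission followed by a passing one.
	entry := []byte(`{"safecheck_requests":[
		{"requestId":"req-1","phase":"request","modality":"image","result":"error"},
		{"requestId":"req-2","phase":"request","modality":"image","result":"pass"}],
		"safecheck_status":"request pass"}`)
	counts, status, err := TallyResults(entry)
	if err != nil {
		panic(err)
	}
	fmt.Println("status:", status)
	fmt.Println("request errors:", counts["request/error"])
}
```

In this sample the legacy status alone would hide the earlier error, which is the situation the event array is meant to expose.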


@@ -0,0 +1,454 @@
package main
import (
"encoding/json"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/iface"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
type aiLogSnapshot struct {
SafecheckRequests []cfg.GuardrailSubmissionEvent `json:"safecheck_requests"`
SafecheckRequestIDs []string `json:"safecheck_request_ids"`
SafecheckRequestID string `json:"safecheck_request_id"`
SafecheckStatus string `json:"safecheck_status"`
}
func readAILogSnapshot(t *testing.T, host test.TestHost) (aiLogSnapshot, string) {
t.Helper()
raw, err := host.GetProperty([]string{wrapper.AILogKey})
require.NoError(t, err)
decoded := wrapper.UnmarshalStr(`"` + string(raw) + `"`)
require.NotEmpty(t, decoded)
var snapshot aiLogSnapshot
require.NoError(t, json.Unmarshal([]byte(decoded), &snapshot))
return snapshot, decoded
}
func requireAILogArraySchema(t *testing.T, raw string) {
t.Helper()
require.True(t, gjson.Get(raw, cfg.SafecheckRequestsKey).IsArray(), "safecheck_requests must be a JSON array")
require.True(t, gjson.Get(raw, cfg.SafecheckRequestIDsKey).IsArray(), "safecheck_request_ids must be a JSON array")
}
func requireSafecheckEvent(t *testing.T, event cfg.GuardrailSubmissionEvent, phase, modality, result, requestID string) {
t.Helper()
require.Equal(t, phase, event.Phase)
require.Equal(t, modality, event.Modality)
require.Equal(t, result, event.Result)
require.Equal(t, requestID, event.RequestID)
}
func TestGuardrailAILogRequestAndResponseEventSchema(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("request pass emits one structured text event", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "Hello"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-structured-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
snapshot, raw := readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 1)
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultPass, "req-structured-pass")
require.Equal(t, []string{"req-structured-pass"}, snapshot.SafecheckRequestIDs)
require.Equal(t, "req-structured-pass", snapshot.SafecheckRequestID)
require.Equal(t, "request pass", snapshot.SafecheckStatus)
})
t.Run("response deny emits one structured text event", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
body := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-structured-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
snapshot, raw := readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 1)
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText, cfg.GuardrailResultDeny, "req-structured-deny")
require.Equal(t, []string{"req-structured-deny"}, snapshot.SafecheckRequestIDs)
require.Equal(t, "req-structured-deny", snapshot.SafecheckRequestID)
require.Equal(t, "response deny", snapshot.SafecheckStatus)
})
})
}
func TestGuardrailAILogStreamingPassFlushesBeforeEOS(t *testing.T) {
streamingFlushConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": false,
"checkResponse": true,
"action": "MultiModalGuard",
"apiType": "text_generation",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
"bufferLimit": 1,
})
return data
}()
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(streamingFlushConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
chunk := []byte("data:{\"id\":\"chatcmpl-1\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n")
host.CallOnHttpStreamingResponseBody(chunk, false)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
snapshot, raw := readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 1)
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText, cfg.GuardrailResultPass, "req-stream-pass")
require.False(t, gjson.Get(raw, "safecheck_status").Exists(), "event-level flush should not wait for a terminal safecheck_status")
})
}
func TestGuardrailAILogErrorFlushAndOrderingForImageSubmissions(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
runCase := func(t *testing.T, firstHeaders [][2]string, firstResponse, firstRequestID string) {
host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})
body := `{"input": {"images": ["https://example.com/a.png", "https://example.com/b.png"]}}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
host.CallOnHttpCall(firstHeaders, []byte(firstResponse))
snapshot, raw := readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 1)
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage, cfg.GuardrailResultError, firstRequestID)
require.Equal(t, []string{firstRequestID}, snapshot.SafecheckRequestIDs)
secondResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-image-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(secondResponse))
snapshot, raw = readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 2)
requireSafecheckEvent(t, snapshot.SafecheckRequests[1], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage, cfg.GuardrailResultPass, "req-image-pass")
require.Equal(t, []string{firstRequestID, "req-image-pass"}, snapshot.SafecheckRequestIDs)
require.Equal(t, "req-image-pass", snapshot.SafecheckRequestID)
}
t.Run("non-200 HTTP response flushes error before next image submission", func(t *testing.T) {
runCase(t, [][2]string{
{":status", "502"},
{"content-type", "application/json"},
}, `{"RequestId": "req-http-error"}`, "req-http-error")
})
t.Run("business failure flushes error before next image submission", func(t *testing.T) {
runCase(t, [][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, `{"Code": 500, "Message": "Failed", "RequestId": "req-business-error"}`, "req-business-error")
})
})
}
func TestGuardrailAILogMalformedRequestIDsAreIgnored(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
cases := []struct {
name string
response string
expectedResult string
}{
{
name: "missing",
response: `{"Code": 200, "Message": "Success", "Data": {"RiskLevel": "low"}}`,
expectedResult: cfg.GuardrailResultPass,
},
{
name: "empty",
response: `{"Code": 200, "Message": "Success", "RequestId": "", "Data": {"RiskLevel": "low"}}`,
expectedResult: cfg.GuardrailResultPass,
},
{
name: "whitespace",
response: `{"Code": 200, "Message": "Success", "RequestId": " ", "Data": {"RiskLevel": "low"}}`,
expectedResult: cfg.GuardrailResultPass,
},
{
name: "non-string",
response: `{"Code": 200, "Message": "Success", "RequestId": 123, "Data": {"RiskLevel": "low"}}`,
expectedResult: cfg.GuardrailResultError,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "Hello"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(tc.response))
snapshot, raw := readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 1)
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, tc.expectedResult, "")
require.Empty(t, snapshot.SafecheckRequestIDs)
require.False(t, gjson.Get(raw, cfg.SafecheckRequestIDKey).Exists())
require.False(t, gjson.Get(raw, cfg.SafecheckRequestsKey+".0.requestId").Exists())
})
}
})
}
func TestGuardrailAILogMaskFallbackRecordsDeny(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(maskConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{
"Code": 200, "Message": "Success", "RequestId": "req-mask-fallback",
"Data": {
"RiskLevel": "none",
"Detail": [{
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
"Result": [{"Label": "phone", "Confidence": 99.0,
"Ext": {"Desensitization": ""}}]
}]
}
}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
snapshot, raw := readAILogSnapshot(t, host)
requireAILogArraySchema(t, raw)
require.Len(t, snapshot.SafecheckRequests, 1)
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultDeny, "req-mask-fallback")
require.Equal(t, []string{"req-mask-fallback"}, snapshot.SafecheckRequestIDs)
})
}
func TestGuardrailAILogDispatchFailureEmitsErrorEvent(t *testing.T) {
ctx := newStubHTTPContext()
eventIndex := cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText)
cfg.CompleteGuardrailSubmissionEventWithRequestID(ctx, eventIndex, "", cfg.GuardrailResultError)
cfg.WriteGuardrailLog(ctx)
events, ok := ctx.GetUserAttribute(cfg.SafecheckRequestsKey).([]cfg.GuardrailSubmissionEvent)
require.True(t, ok)
require.Len(t, events, 1)
requireSafecheckEvent(t, events[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultError, "")
requestIDs, ok := ctx.GetUserAttribute(cfg.SafecheckRequestIDsKey).([]string)
require.True(t, ok)
require.Empty(t, requestIDs)
require.Nil(t, ctx.GetUserAttribute(cfg.SafecheckRequestIDKey))
require.Equal(t, []string{wrapper.AILogKey}, ctx.writes)
}
type stubHTTPContext struct {
userContext map[string]interface{}
userAttribute map[string]interface{}
bufferQueue [][]byte
writes []string
routeCallError error
}
func newStubHTTPContext() *stubHTTPContext {
return &stubHTTPContext{
userContext: map[string]interface{}{},
userAttribute: map[string]interface{}{},
}
}
func (ctx *stubHTTPContext) Scheme() string { return "" }
func (ctx *stubHTTPContext) Host() string { return "" }
func (ctx *stubHTTPContext) Path() string { return "" }
func (ctx *stubHTTPContext) Method() string { return "" }
func (ctx *stubHTTPContext) SetContext(key string, value interface{}) {
ctx.userContext[key] = value
}
func (ctx *stubHTTPContext) GetContext(key string) interface{} {
return ctx.userContext[key]
}
func (ctx *stubHTTPContext) GetBoolContext(key string, defaultValue bool) bool {
if value, ok := ctx.userContext[key].(bool); ok {
return value
}
return defaultValue
}
func (ctx *stubHTTPContext) GetStringContext(key, defaultValue string) string {
if value, ok := ctx.userContext[key].(string); ok {
return value
}
return defaultValue
}
func (ctx *stubHTTPContext) GetByteSliceContext(key string, defaultValue []byte) []byte {
if value, ok := ctx.userContext[key].([]byte); ok {
return value
}
return defaultValue
}
func (ctx *stubHTTPContext) GetUserAttribute(key string) interface{} {
return ctx.userAttribute[key]
}
func (ctx *stubHTTPContext) SetUserAttribute(key string, value interface{}) {
ctx.userAttribute[key] = value
}
func (ctx *stubHTTPContext) SetUserAttributeMap(kvmap map[string]interface{}) {
ctx.userAttribute = kvmap
}
func (ctx *stubHTTPContext) GetUserAttributeMap() map[string]interface{} {
return ctx.userAttribute
}
func (ctx *stubHTTPContext) WriteUserAttributeToLog() error {
return ctx.WriteUserAttributeToLogWithKey(wrapper.CustomLogKey)
}
func (ctx *stubHTTPContext) WriteUserAttributeToLogWithKey(key string) error {
ctx.writes = append(ctx.writes, key)
return nil
}
func (ctx *stubHTTPContext) WriteUserAttributeToTrace() error { return nil }
func (ctx *stubHTTPContext) DontReadRequestBody() {}
func (ctx *stubHTTPContext) DontReadResponseBody() {}
func (ctx *stubHTTPContext) BufferRequestBody() {}
func (ctx *stubHTTPContext) BufferResponseBody() {}
func (ctx *stubHTTPContext) NeedPauseStreamingResponse() {}
func (ctx *stubHTTPContext) PushBuffer(buffer []byte) {
ctx.bufferQueue = append(ctx.bufferQueue, buffer)
}
func (ctx *stubHTTPContext) PopBuffer() []byte {
if len(ctx.bufferQueue) == 0 {
return nil
}
buffer := ctx.bufferQueue[0]
ctx.bufferQueue = ctx.bufferQueue[1:]
return buffer
}
func (ctx *stubHTTPContext) BufferQueueSize() int { return len(ctx.bufferQueue) }
func (ctx *stubHTTPContext) DisableReroute() {}
func (ctx *stubHTTPContext) SetRequestBodyBufferLimit(uint32) {
}
func (ctx *stubHTTPContext) SetResponseBodyBufferLimit(uint32) {
}
func (ctx *stubHTTPContext) RouteCall(string, string, [][2]string, []byte, iface.RouteResponseCallback) error {
return ctx.routeCallError
}
func (ctx *stubHTTPContext) GetExecutionPhase() iface.HTTPExecutionPhase {
return iface.DecodeData
}
func (ctx *stubHTTPContext) HasRequestBody() bool { return true }
func (ctx *stubHTTPContext) HasResponseBody() bool { return true }
func (ctx *stubHTTPContext) IsWebsocket() bool { return false }
func (ctx *stubHTTPContext) IsBinaryRequestBody() bool { return false }
func (ctx *stubHTTPContext) IsBinaryResponseBody() bool {
return false
}


@@ -0,0 +1,126 @@
package config
import (
"strings"
"time"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
SafecheckRequestsKey = "safecheck_requests"
SafecheckRequestIDsKey = "safecheck_request_ids"
SafecheckRequestIDKey = "safecheck_request_id"
GuardrailPhaseRequest = "request"
GuardrailPhaseResponse = "response"
GuardrailModalityText = "text"
GuardrailModalityImage = "image"
GuardrailModalityMCP = "mcp"
GuardrailResultPass = "pass"
GuardrailResultDeny = "deny"
GuardrailResultMask = "mask"
GuardrailResultError = "error"
)
type GuardrailSubmissionEvent struct {
RequestID string `json:"requestId,omitempty"`
Phase string `json:"phase"`
Modality string `json:"modality"`
Result string `json:"result"`
}
// BeginGuardrailSubmissionEvent appends a placeholder event so append order matches
// the current serial submission order. The event is flushed only after completion.
func BeginGuardrailSubmissionEvent(ctx wrapper.HttpContext, phase, modality string) int {
events := getGuardrailSubmissionEvents(ctx)
events = append(events, GuardrailSubmissionEvent{
Phase: phase,
Modality: modality,
})
setGuardrailSubmissionEvents(ctx, events)
return len(events) - 1
}
func CompleteGuardrailSubmissionEvent(ctx wrapper.HttpContext, index int, responseBody []byte, result string) {
CompleteGuardrailSubmissionEventWithRequestID(ctx, index, ExtractValidRequestID(responseBody), result)
}
func CompleteGuardrailSubmissionEventWithRequestID(ctx wrapper.HttpContext, index int, requestID, result string) {
events := getGuardrailSubmissionEvents(ctx)
if index < 0 || index >= len(events) {
return
}
events[index].Result = result
if requestID != "" {
events[index].RequestID = requestID
}
setGuardrailSubmissionEvents(ctx, events)
}
// WriteGuardrailLog writes current guardrail-related user attributes to the AI log.
// Call after submission events are updated; Complete* does not flush the log.
func WriteGuardrailLog(ctx wrapper.HttpContext) {
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
// MarkGuardrailRequestError finalizes a request-phase safecheck submission that failed
// upstream (HTTP, unmarshal, downstream build, or dispatch error). It overwrites the
// legacy safecheck_status with "request error" so consumers that only watch the single
// status field do not see a stale "request pass" left over from a prior chunk, and
// records safecheck_request_rt so latency metrics cover failures too.
func MarkGuardrailRequestError(ctx wrapper.HttpContext, index int, responseBody []byte, startTime int64) {
ctx.SetUserAttribute("safecheck_request_rt", time.Now().UnixMilli()-startTime)
ctx.SetUserAttribute("safecheck_status", "request error")
CompleteGuardrailSubmissionEvent(ctx, index, responseBody, GuardrailResultError)
WriteGuardrailLog(ctx)
}
// MarkGuardrailResponseError is the response-phase counterpart of MarkGuardrailRequestError.
func MarkGuardrailResponseError(ctx wrapper.HttpContext, index int, responseBody []byte, startTime int64) {
ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime)
ctx.SetUserAttribute("safecheck_status", "response error")
CompleteGuardrailSubmissionEvent(ctx, index, responseBody, GuardrailResultError)
WriteGuardrailLog(ctx)
}
func ExtractValidRequestID(responseBody []byte) string {
if len(responseBody) == 0 {
return ""
}
requestID := gjson.GetBytes(responseBody, "RequestId")
if !requestID.Exists() || requestID.Type != gjson.String {
return ""
}
trimmed := strings.TrimSpace(requestID.String())
if trimmed == "" {
return ""
}
return trimmed
}
func getGuardrailSubmissionEvents(ctx wrapper.HttpContext) []GuardrailSubmissionEvent {
events, ok := ctx.GetUserAttribute(SafecheckRequestsKey).([]GuardrailSubmissionEvent)
if !ok || events == nil {
return []GuardrailSubmissionEvent{}
}
return events
}
func setGuardrailSubmissionEvents(ctx wrapper.HttpContext, events []GuardrailSubmissionEvent) {
ctx.SetUserAttribute(SafecheckRequestsKey, events)
requestIDs := make([]string, 0, len(events))
for _, event := range events {
if event.RequestID != "" {
requestIDs = append(requestIDs, event.RequestID)
}
}
ctx.SetUserAttribute(SafecheckRequestIDsKey, requestIDs)
if len(requestIDs) > 0 {
ctx.SetUserAttribute(SafecheckRequestIDKey, requestIDs[len(requestIDs)-1])
}
}
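
The event bookkeeping in this file can be approximated outside the plugin as follows. This is a minimal sketch, not the plugin's code: `eventLog`, `submissionEvent`, and `extractRequestID` are hypothetical names, an in-memory slice stands in for the user attributes kept on `wrapper.HttpContext`, and `encoding/json` replaces `gjson` (so key matching is case-insensitive here, unlike the original).

```go
package main

import (
	"encoding/json"
	"strings"
)

// submissionEvent mirrors the shape of GuardrailSubmissionEvent, for illustration only.
type submissionEvent struct {
	RequestID string `json:"requestId,omitempty"`
	Phase     string `json:"phase"`
	Modality  string `json:"modality"`
	Result    string `json:"result"`
}

// eventLog is a hypothetical in-memory stand-in for the per-request user attributes.
type eventLog struct {
	events []submissionEvent
}

// begin appends a placeholder event and returns its index, so that the
// append order matches the order in which checks are submitted.
func (l *eventLog) begin(phase, modality string) int {
	l.events = append(l.events, submissionEvent{Phase: phase, Modality: modality})
	return len(l.events) - 1
}

// complete records the result and, when one can be extracted, the request ID.
// Out-of-range indexes are ignored.
func (l *eventLog) complete(index int, responseBody []byte, result string) {
	if index < 0 || index >= len(l.events) {
		return
	}
	l.events[index].Result = result
	if id := extractRequestID(responseBody); id != "" {
		l.events[index].RequestID = id
	}
}

// requestIDs returns the non-empty request IDs in submission order.
func (l *eventLog) requestIDs() []string {
	ids := make([]string, 0, len(l.events))
	for _, e := range l.events {
		if e.RequestID != "" {
			ids = append(ids, e.RequestID)
		}
	}
	return ids
}

// extractRequestID accepts RequestId only when it is a JSON string that is
// non-empty after trimming; anything else yields "".
func extractRequestID(responseBody []byte) string {
	var envelope struct {
		RequestID json.RawMessage `json:"RequestId"`
	}
	if err := json.Unmarshal(responseBody, &envelope); err != nil {
		return ""
	}
	var id string
	if err := json.Unmarshal(envelope.RequestID, &id); err != nil {
		return "" // missing, numeric, or otherwise not a string
	}
	return strings.TrimSpace(id)
}
```

With this shape, an error event whose response carries a numeric `RequestId` still appears in the event list but contributes nothing to the request-ID list, which is the behaviour the table-driven test earlier in this diff asserts.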


@@ -6,12 +6,16 @@ import (
"fmt"
"regexp"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type OpenAIDenyResponseFormat string
const (
MaxRisk = "max"
HighRisk = "high"
@@ -35,10 +39,31 @@ const (
WaterMarkType = "waterMark"
// Default configurations
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
// Template parameter order:
// OpenAIResponseFormatLegacy: id, created (unix sec), content
// OpenAIResponseFormatStructured: id, created (unix sec), content, x_higress_guardrail JSON
// OpenAIStreamResponseChunk: id, created, content
// OpenAIStreamResponseEndLegacy: id, created
// OpenAIStreamResponseEndStructured: id, created, x_higress_guardrail JSON
// OpenAIStreamResponseFormatLegacy: id, created, content, id, created
// OpenAIStreamResponseFormatStructured: id, created, content, id, created, x_higress_guardrail JSON
// `created` is required by openai-python (ChatCompletion.created is non-Optional).
// `finish_reason: "stop"` preserves wire-level compatibility with downstream
// consumers (LangChain / LiteLLM / SDKs / BI dashboards) that treat `stop` as
// "valid completion"; the moderation-event signal lives in the nested
// `choices[0].x_higress_guardrail` block (denyCode / blockedDetails) instead.
OpenAIResponseFormatLegacy = `{"id":"%s","object":"chat.completion","created":%d,"model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIResponseFormatStructured = `{"id":"%s","object":"chat.completion","created":%d,"model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop","x_higress_guardrail":%s}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEndLegacy = `data:{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseEndStructured = `data:{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop","x_higress_guardrail":%s}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormatLegacy = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEndLegacy + "\n\n" + `data: [DONE]`
OpenAIStreamResponseFormatStructured = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEndStructured + "\n\n" + `data: [DONE]`
OpenAIDenyResponseFormatLegacy OpenAIDenyResponseFormat = "legacy"
OpenAIDenyResponseFormatStructured OpenAIDenyResponseFormat = "structured"
OpenAIDenyResponseFormatConsumerScopeError = "openAIDenyResponseFormat must be configured at plugin global scope, not under consumerRiskLevel"
DefaultDenyCode = 200
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
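
The parameter order documented for the structured template can be exercised with a small sketch. The template string is copied from `OpenAIResponseFormatStructured` above; `escapeJSONString`, `renderStructuredDeny`, and `inspectDeny` are hypothetical helpers, and `escapeJSONString` is only an assumed approximation of what `wrapper.MarshalStr` does (JSON-escape a string without the surrounding quotes).

```go
package main

import (
	"encoding/json"
	"fmt"
)

// structuredTemplate is copied from OpenAIResponseFormatStructured.
// Parameter order: id, created (unix seconds), content, x_higress_guardrail JSON.
const structuredTemplate = `{"id":"%s","object":"chat.completion","created":%d,"model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop","x_higress_guardrail":%s}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`

// escapeJSONString is a hypothetical stand-in for wrapper.MarshalStr:
// it JSON-escapes s and drops the surrounding quotes so the result can be
// placed inside the quoted "content" slot of the template.
func escapeJSONString(s string) string {
	b, err := json.Marshal(s)
	if err != nil {
		return ""
	}
	return string(b[1 : len(b)-1])
}

// renderStructuredDeny fills the template in the documented order.
// guardrail must already be a valid JSON object; it is inserted unquoted.
func renderStructuredDeny(id string, created int64, content string, guardrail []byte) string {
	return fmt.Sprintf(structuredTemplate, id, created, escapeJSONString(content), guardrail)
}

// inspectDeny parses a rendered response and returns
// choices[0].x_higress_guardrail.code and choices[0].message.content.
func inspectDeny(rendered string) (int, string, error) {
	var resp struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
			Guardrail struct {
				Code int `json:"code"`
			} `json:"x_higress_guardrail"`
		} `json:"choices"`
	}
	if err := json.Unmarshal([]byte(rendered), &resp); err != nil {
		return 0, "", err
	}
	if len(resp.Choices) == 0 {
		return 0, "", fmt.Errorf("no choices in rendered response")
	}
	return resp.Choices[0].Guardrail.Code, resp.Choices[0].Message.Content, nil
}
```

The content slot is escaped because it sits inside quotes in the template, while the guardrail slot is raw JSON; swapping the two arguments would produce an invalid document, which is why the order comment matters.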
@@ -184,6 +209,7 @@ type AISecurityConfig struct {
DenyCode int64
DenyMessage string
ProtocolOriginal bool
OpenAIDenyResponseFormat OpenAIDenyResponseFormat
RiskLevelBar string
ContentModerationLevelBar string
PromptAttackLevelBar string
@@ -296,6 +322,16 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error {
config.CheckRequestImage = json.Get("checkRequestImage").Bool()
config.CheckResponse = json.Get("checkResponse").Bool()
config.ProtocolOriginal = json.Get("protocol").String() == "original"
if obj := json.Get("openAIDenyResponseFormat"); obj.Exists() {
switch OpenAIDenyResponseFormat(obj.String()) {
case OpenAIDenyResponseFormatLegacy:
config.OpenAIDenyResponseFormat = OpenAIDenyResponseFormatLegacy
case OpenAIDenyResponseFormatStructured:
config.OpenAIDenyResponseFormat = OpenAIDenyResponseFormatStructured
default:
return errors.New("invalid openAIDenyResponseFormat, value must be one of [legacy, structured]")
}
}
config.DenyMessage = json.Get("denyMessage").String()
if obj := json.Get("denyCode"); obj.Exists() {
config.DenyCode = obj.Int()
@@ -411,6 +447,9 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error {
for k, v := range item.Map() {
m[k] = v.Value()
}
if _, ok := m["openAIDenyResponseFormat"]; ok {
return errors.New(OpenAIDenyResponseFormatConsumerScopeError)
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
@@ -531,6 +570,7 @@ func (config *AISecurityConfig) SetDefaultValues() {
config.ApiType = ApiTextGeneration
config.ProviderType = ProviderOpenAI
config.RiskAction = "block"
config.OpenAIDenyResponseFormat = OpenAIDenyResponseFormatLegacy
}
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
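
The validation added to `Parse` and the default set in `SetDefaultValues` can be sketched as one standalone function. This is an illustration under assumptions: `parseDenyFormat` is a hypothetical name, `encoding/json` stands in for `gjson`, and the per-consumer list is assumed to live under a `consumerRiskLevel` key, as the error message in the diff suggests.

```go
package main

import (
	"encoding/json"
	"errors"
)

const (
	formatLegacy     = "legacy"
	formatStructured = "structured"
)

var errInvalidFormat = errors.New("invalid openAIDenyResponseFormat, value must be one of [legacy, structured]")

// parseDenyFormat returns the effective deny format for a plugin config.
// An absent key falls back to legacy; an unknown value or a per-consumer
// override is rejected.
func parseDenyFormat(pluginConfig []byte) (string, error) {
	var raw struct {
		Format            json.RawMessage              `json:"openAIDenyResponseFormat"`
		ConsumerRiskLevel []map[string]json.RawMessage `json:"consumerRiskLevel"` // assumed key name
	}
	if err := json.Unmarshal(pluginConfig, &raw); err != nil {
		return "", err
	}

	format := formatLegacy // default, as in SetDefaultValues
	if len(raw.Format) > 0 {
		var value string
		if err := json.Unmarshal(raw.Format, &value); err != nil {
			return "", errInvalidFormat
		}
		switch value {
		case formatLegacy, formatStructured:
			format = value
		default:
			return "", errInvalidFormat
		}
	}

	// The format is a global setting; reject it inside any consumer entry.
	for _, item := range raw.ConsumerRiskLevel {
		if _, ok := item["openAIDenyResponseFormat"]; ok {
			return "", errors.New("openAIDenyResponseFormat must be configured at plugin global scope, not under consumerRiskLevel")
		}
	}
	return format, nil
}
```

Rejecting the key at consumer scope, rather than silently ignoring it, keeps a misplaced setting from looking as if it had taken effect.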
@@ -951,6 +991,151 @@ func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer
return json.Marshal(body)
}
// ResolveDenyMessage returns the human-readable deny text used both for
// non-original OpenAI wrappers (message.content / delta.content) and for the
// x_higress_guardrail.denyMessage field, ensuring the two stay aligned.
func ResolveDenyMessage(config AISecurityConfig) string {
if config.DenyMessage != "" {
return config.DenyMessage
}
return DefaultDenyMessage
}
// BuildOpenAIDenyResponseBody builds the guardrail JSON object embedded by the
// outer OpenAI structured template as choices[0].x_higress_guardrail. Its shape
// mirrors DenyResponseBody, but DenyMessage is filled via ResolveDenyMessage so
// the field is always present and consistent with the rendered content.
// Code is sourced from config.DenyCode so that x_higress_guardrail.code
// consistently represents "the HTTP status this gateway returns to the client",
// aligned with BuildOpenAIFallbackDenyResponseBody.
func BuildOpenAIDenyResponseBody(response Response, config AISecurityConfig, consumer string) ([]byte, error) {
details := GetUnacceptableDetail(response.Data, config, consumer)
blocked := make([]BlockedDetail, 0, len(details))
for _, d := range details {
blocked = append(blocked, BlockedDetail{
Type: d.Type,
Level: d.Level,
})
}
body := DenyResponseBody{
Code: int(config.DenyCode),
DenyMessage: ResolveDenyMessage(config),
BlockedDetails: blocked,
}
return json.Marshal(body)
}
// BuildOpenAIFallbackDenyResponseBody builds the guardrail JSON object embedded
// by the outer OpenAI structured template as choices[0].x_higress_guardrail for
// mask→block fallback paths. The fallback is triggered by
// ReplaceJsonFieldTextContent failure or empty desensitization, so there is no
// upstream Response object to derive blockedDetails from.
func BuildOpenAIFallbackDenyResponseBody(config AISecurityConfig) ([]byte, error) {
body := DenyResponseBody{
Code: int(config.DenyCode),
DenyMessage: ResolveDenyMessage(config),
BlockedDetails: []BlockedDetail{},
}
return json.Marshal(body)
}
func openAIDenyContentType(isStream bool) [][2]string {
if isStream {
return [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}
}
return [][2]string{{"content-type", "application/json"}}
}
// BuildOpenAIDenyData builds the complete OpenAI-formatted deny response bytes
// (structured or legacy). Callers that need raw bytes (e.g. streaming response
// handlers using InjectEncodedDataToFilterChain) use this directly; callers that
// want a full SendHttpResponse dispatch should use SendDenyResponse instead.
func BuildOpenAIDenyData(config AISecurityConfig, response Response, consumer string, isStream bool) ([]byte, error) {
if config.OpenAIDenyResponseFormat == OpenAIDenyResponseFormatStructured {
guardrailBody, err := BuildOpenAIDenyResponseBody(response, config, consumer)
if err != nil {
return nil, err
}
marshalledDenyMessage := wrapper.MarshalStr(ResolveDenyMessage(config))
randomID := utils.GenerateRandomChatID()
createdTs := time.Now().Unix()
if isStream {
return []byte(fmt.Sprintf(OpenAIStreamResponseFormatStructured, randomID, createdTs, marshalledDenyMessage, randomID, createdTs, string(guardrailBody))), nil
}
return []byte(fmt.Sprintf(OpenAIResponseFormatStructured, randomID, createdTs, marshalledDenyMessage, string(guardrailBody))), nil
}
denyBody, err := BuildDenyResponseBody(response, config, consumer)
if err != nil {
return nil, err
}
marshalledDenyBody := wrapper.MarshalStr(string(denyBody))
randomID := utils.GenerateRandomChatID()
createdTs := time.Now().Unix()
if isStream {
return []byte(fmt.Sprintf(OpenAIStreamResponseFormatLegacy, randomID, createdTs, marshalledDenyBody, randomID, createdTs)), nil
}
return []byte(fmt.Sprintf(OpenAIResponseFormatLegacy, randomID, createdTs, marshalledDenyBody)), nil
}
// SendDenyResponse dispatches a deny HTTP response in the appropriate format
// (ProtocolOriginal, Structured, or Legacy). It returns an error only if
// building the response body fails; the caller should handle the error
// (e.g. log, mark guardrail error, resume request/response).
func SendDenyResponse(config AISecurityConfig, response Response, consumer string, isStream bool) error {
if config.ProtocolOriginal {
denyBody, err := BuildDenyResponseBody(response, config, consumer)
if err != nil {
return err
}
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
return nil
}
data, err := BuildOpenAIDenyData(config, response, consumer, isStream)
if err != nil {
return err
}
proxywasm.SendHttpResponse(uint32(config.DenyCode), openAIDenyContentType(isStream), data, -1)
return nil
}
// SendFallbackDenyResponse dispatches a fallback deny HTTP response when no
// upstream Response object is available (e.g. mask-to-block on replace error
// or empty desensitization). For ProtocolOriginal it sends the plain deny
// message; for OpenAI formats it wraps it in the appropriate template with
// an empty-blockedDetails guardrail object (structured) or the deny message
// as content (legacy).
func SendFallbackDenyResponse(config AISecurityConfig, isStream bool) error {
marshalledDenyMessage := wrapper.MarshalStr(ResolveDenyMessage(config))
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
return nil
}
randomID := utils.GenerateRandomChatID()
createdTs := time.Now().Unix()
if config.OpenAIDenyResponseFormat == OpenAIDenyResponseFormatStructured {
guardrailBody, err := BuildOpenAIFallbackDenyResponseBody(config)
if err != nil {
return err
}
var data []byte
if isStream {
data = []byte(fmt.Sprintf(OpenAIStreamResponseFormatStructured, randomID, createdTs, marshalledDenyMessage, randomID, createdTs, string(guardrailBody)))
} else {
data = []byte(fmt.Sprintf(OpenAIResponseFormatStructured, randomID, createdTs, marshalledDenyMessage, string(guardrailBody)))
}
proxywasm.SendHttpResponse(uint32(config.DenyCode), openAIDenyContentType(isStream), data, -1)
return nil
}
var data []byte
if isStream {
data = []byte(fmt.Sprintf(OpenAIStreamResponseFormatLegacy, randomID, createdTs, marshalledDenyMessage, randomID, createdTs))
} else {
data = []byte(fmt.Sprintf(OpenAIResponseFormatLegacy, randomID, createdTs, marshalledDenyMessage))
}
proxywasm.SendHttpResponse(uint32(config.DenyCode), openAIDenyContentType(isStream), data, -1)
return nil
}
func GetUnacceptableDetail(data Data, config AISecurityConfig, consumer string) []Detail {
result := []Detail{}
for _, detail := range data.Detail {
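
The guardrail object built by `BuildOpenAIDenyResponseBody` and `BuildOpenAIFallbackDenyResponseBody` can be sketched as below. The struct definitions and JSON tags are assumptions (the real `DenyResponseBody` and `BlockedDetail` types are not shown in this diff), and `buildGuardrailBody` is a hypothetical helper that merges both builders for brevity.

```go
package main

import "encoding/json"

// defaultDenyMessage is copied from DefaultDenyMessage in the diff.
const defaultDenyMessage = "很抱歉,我无法回答您的问题"

// blockedDetail and denyBody are assumed shapes; the JSON tags are illustrative.
type blockedDetail struct {
	Type  string `json:"type"`
	Level string `json:"level"`
}

type denyBody struct {
	Code           int             `json:"code"`
	DenyMessage    string          `json:"denyMessage"`
	BlockedDetails []blockedDetail `json:"blockedDetails"`
}

// resolveDenyMessage prefers the configured message and falls back to the default.
func resolveDenyMessage(configured string) string {
	if configured != "" {
		return configured
	}
	return defaultDenyMessage
}

// buildGuardrailBody marshals the guardrail object. A nil details slice is
// replaced with an empty one so the fallback path serializes
// "blockedDetails" as [] rather than null.
func buildGuardrailBody(denyCode int, configuredMessage string, details []blockedDetail) ([]byte, error) {
	if details == nil {
		details = []blockedDetail{}
	}
	return json.Marshal(denyBody{
		Code:           denyCode,
		DenyMessage:    resolveDenyMessage(configuredMessage),
		BlockedDetails: details,
	})
}
```

Using a non-nil empty slice matters to clients that iterate over `blockedDetails` without a null check.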


@@ -3,7 +3,6 @@ package text
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
@@ -21,6 +20,7 @@ import (
const (
responseFallbackPathsCtxKey = "response_fallback_paths"
responseStreamFallbackPathsCtxKey = "response_stream_fallback_paths"
responseStartTimeCtxKey = "response_start_time"
)
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
@@ -28,6 +28,7 @@ func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISe
ctx.SetContext("end_of_stream_received", false)
ctx.SetContext("during_call", false)
ctx.SetContext("risk_detected", false)
ctx.SetContext(responseStartTimeCtxKey, time.Now().UnixMilli())
ctx.SetContext(responseFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths))
ctx.SetContext(responseStreamFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths))
sessionID, _ := utils.GenerateHexID(20)
@@ -52,10 +53,13 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
sessionID, _ = ctx.GetContext("sessionID").(string)
}
var bufferQueue [][]byte
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
if ctx.GetContext("end_of_stream_received").(bool) {
proxywasm.ResumeHttpResponse()
}
@@ -66,6 +70,8 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
if ctx.GetContext("end_of_stream_received").(bool) {
proxywasm.ResumeHttpResponse()
}
@@ -73,24 +79,50 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
return
}
if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
jsonData, err := cfg.BuildOpenAIDenyData(config, response, consumer, true)
if err != nil {
// Build failure → fail-open: inject the buffered upstream content as-is.
// Record this path with the same signals as the deny path (counter,
// safecheck_response_rt, safecheck_status, AI log) so operators can detect
// the silent passthrough instead of assuming every risky response was denied.
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError)
log.Errorf("failed to build deny response body: %v", err)
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
bufferQueue = [][]byte{}
config.IncrementCounter("ai_sec_response_deny_buildfail", 1)
startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime)
ctx.SetUserAttribute("safecheck_status", "build_fallback_pass")
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
if !endStream {
ctx.SetContext("during_call", false)
singleCall()
}
return
}
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
ctx.SetContext("risk_detected", true)
ctx.SetContext("during_call", false)
config.IncrementCounter("ai_sec_response_deny", 1)
startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime)
ctx.SetUserAttribute("safecheck_status", "response deny")
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
cfg.WriteGuardrailLog(ctx)
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
bufferQueue = [][]byte{}
@@ -126,10 +158,13 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
ctx.SetContext("during_call", true)
log.Debugf("current content piece: %s", buffer)
checkService := config.GetResponseCheckService(consumer)
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText)
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed to call the safe check service: %v", err)
startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
if ctx.GetContext("end_of_stream_received").(bool) {
proxywasm.ResumeHttpResponse()
}
@@ -198,10 +233,12 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
}
contentIndex := 0
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
@@ -209,6 +246,7 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
@@ -217,41 +255,32 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(content) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
}
return
}
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
if err := cfg.SendDenyResponse(config, response, consumer, isStreamingResponse); err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if isStreamingResponse {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
config.IncrementCounter("ai_sec_response_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCall = func() {
var nextContentIndex int
@@ -264,10 +293,12 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
checkService := config.GetResponseCheckService(consumer)
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText)
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed to call the safe check service: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpResponse()
}
}
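
The error-path calls in the handlers above share one gate and one set of attributes, which can be sketched in isolation. `guardrailCallFailed` and `errorAttributes` are hypothetical names; the gate is a simplified approximation that uses `encoding/json` instead of `gjson`, so a body that is not valid JSON counts as a failure here.

```go
package main

import "encoding/json"

// guardrailCallFailed reports whether a safety-check call should be treated
// as an error: a non-200 HTTP status, an unparsable body, or a body whose
// top-level Code is not 200.
func guardrailCallFailed(statusCode int, responseBody []byte) bool {
	if statusCode != 200 {
		return true
	}
	var envelope struct {
		Code int64 `json:"Code"`
	}
	if err := json.Unmarshal(responseBody, &envelope); err != nil {
		return true
	}
	return envelope.Code != 200
}

// errorAttributes returns the attributes that the Mark*Error helpers set:
// the elapsed time in milliseconds and the legacy single status field.
// phase is "request" or "response".
func errorAttributes(phase string, startMillis, nowMillis int64) map[string]interface{} {
	return map[string]interface{}{
		"safecheck_" + phase + "_rt": nowMillis - startMillis,
		"safecheck_status":           phase + " error",
	}
}
```

Recording the round-trip time on failures as well as on passes and denies keeps latency metrics from being skewed toward successful calls.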


@@ -54,11 +54,14 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
contentIndex := 0
imageIndex := 0
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
currentImageSubmissionIndex := 0
var singleCall func()
var singleCallForImage func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -66,6 +69,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -74,10 +78,13 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(content) {
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
}
} else {
@@ -88,6 +95,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -97,13 +105,15 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText)
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
@@ -117,6 +127,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
@@ -125,6 +136,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
if imageIndex < len(images) {
singleCallForImage()
} else {
@@ -136,6 +148,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
if imageIndex < len(images) {
singleCallForImage()
} else {
@@ -148,7 +161,10 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
if imageIndex >= len(images) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if imageIndex >= len(images) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
} else {
singleCallForImage()
@@ -159,6 +175,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -167,13 +184,15 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCallForImage = func() {
currentImageSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage)
img := images[imageIndex]
imgUrl := ""
imgBase64 := ""
@@ -186,6 +205,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
@@ -207,11 +227,13 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg
return types.ActionContinue
}
imageIndex := 0
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
if imageIndex < len(imgResults) {
singleCall()
} else {
@@ -223,6 +245,7 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
if imageIndex < len(imgResults) {
singleCall()
} else {
@@ -233,9 +256,16 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(imgResults) {
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
// Transitional double-write: this handler historically wrote safecheck_request_rt
// for response-phase emissions; existing dashboards key off the old name.
// Remove after 12 release cycles.
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
ctx.SetUserAttribute("safecheck_status", "response pass")
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if imageIndex >= len(imgResults) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
@@ -245,16 +275,25 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
config.IncrementCounter("ai_sec_response_deny", 1)
// Transitional double-write: this handler historically incremented ai_sec_request_deny
// and wrote safecheck_request_rt for response-phase denies; existing dashboards/alerts
// key off the old names. The legacy safecheck_status="reqeust deny" (typo) is dropped
// in this transition. Remove after 12 release cycles.
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
ctx.SetUserAttribute("safecheck_status", "response deny")
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityImage)
img := imgResults[imageIndex]
imgUrl := ""
imgBase64 := ""
@@ -267,6 +306,7 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpResponse()
}
}

@@ -212,11 +212,14 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
contentIndex := 0
imageIndex := 0
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
currentImageSubmissionIndex := 0
var singleCall func()
var singleCallForImage func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -224,6 +227,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -232,10 +236,13 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(content) {
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
}
} else {
@@ -246,6 +253,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -255,13 +263,15 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText)
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
@@ -275,6 +285,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
@@ -283,6 +294,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
if imageIndex < len(images) {
singleCallForImage()
} else {
@@ -294,6 +306,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
if imageIndex < len(images) {
singleCallForImage()
} else {
@@ -306,7 +319,10 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
if imageIndex >= len(images) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if imageIndex >= len(images) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
} else {
singleCallForImage()
@@ -317,6 +333,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -325,13 +342,15 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCallForImage = func() {
currentImageSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage)
img := images[imageIndex]
imgUrl := ""
imgBase64 := ""
@@ -344,6 +363,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
@@ -365,11 +385,13 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
return types.ActionContinue
}
imageIndex := 0
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
if imageIndex < len(imgUrls) {
singleCall()
} else {
@@ -381,6 +403,7 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
if imageIndex < len(imgUrls) {
singleCall()
} else {
@@ -391,9 +414,16 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(imgUrls) {
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
// Transitional double-write: this handler historically wrote safecheck_request_rt
// for response-phase emissions; existing dashboards key off the old name.
// Remove after 12 release cycles.
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
ctx.SetUserAttribute("safecheck_status", "response pass")
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if imageIndex >= len(imgUrls) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
@@ -403,21 +433,31 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
config.IncrementCounter("ai_sec_response_deny", 1)
// Transitional double-write: this handler historically incremented ai_sec_request_deny
// and wrote safecheck_request_rt for response-phase denies; existing dashboards/alerts
// key off the old names. The legacy safecheck_status="reqeust deny" (typo) is dropped
// in this transition. Remove after 12 release cycles.
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
ctx.SetUserAttribute("safecheck_status", "response deny")
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityImage)
imgUrl := imgUrls[imageIndex]
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpResponse()
}
}

@@ -43,10 +43,12 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
}
contentIndex := 0
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -54,6 +56,7 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -62,7 +65,10 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(content) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
} else {
singleCall()
@@ -74,22 +80,25 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
denyResponse := fmt.Sprintf(DenyResponse, marshalledDenyMessage)
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyResponse), -1)
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityMCP)
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
@@ -103,6 +112,7 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
@@ -114,6 +124,7 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
consumer, _ := ctx.GetContext("consumer").(string)
var frontBuffer []byte
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer func() {
@@ -122,6 +133,9 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri
}()
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
ctx.SetUserAttribute("safecheck_status", "response error")
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError)
cfg.WriteGuardrailLog(ctx)
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
return
}
@@ -129,6 +143,9 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
ctx.SetUserAttribute("safecheck_status", "response error")
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError)
cfg.WriteGuardrailLog(ctx)
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
return
}
@@ -136,13 +153,20 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
ctx.SetUserAttribute("safecheck_status", "response error")
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError)
cfg.WriteGuardrailLog(ctx)
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
return
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
denySSEResponse := fmt.Sprintf(DenySSEResponse, marshalledDenyMessage)
proxywasm.InjectEncodedDataToFilterChain([]byte(denySSEResponse), true)
} else {
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
cfg.WriteGuardrailLog(ctx)
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
}
}
@@ -158,10 +182,14 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri
ctx.SetContext("during_call", true)
checkService := config.GetResponseCheckService(consumer)
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityMCP)
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, msg, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
ctx.SetUserAttribute("safecheck_status", "response error")
cfg.CompleteGuardrailSubmissionEventWithRequestID(ctx, currentSubmissionIndex, "", cfg.GuardrailResultError)
cfg.WriteGuardrailLog(ctx)
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
ctx.SetContext("during_call", false)
}
@@ -194,10 +222,12 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
}
contentIndex := 0
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
@@ -205,6 +235,7 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
@@ -213,7 +244,10 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(content) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
@@ -224,17 +258,19 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpResponse()
return
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
denyResponseBody := fmt.Sprintf(DenyResponse, marshalledDenyMessage)
proxywasm.RemoveHttpResponseHeader("content-length")
@@ -253,10 +289,12 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
checkService := config.GetResponseCheckService(consumer)
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityMCP)
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpResponse()
}
}

@@ -2,7 +2,6 @@ package text
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
@@ -67,6 +66,8 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
hasMasked := false
maskedContent := []byte(content)
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
currentImageSubmissionIndex := 0
var singleCall func()
var singleCallForImage func()
// prevContentIndex tracks the start of the current chunk for masking replacement
@@ -74,6 +75,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -81,6 +83,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -97,26 +100,17 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
if replaceErr != nil {
log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
// Fall back to block to prevent leaking sensitive data
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
if sendErr := cfg.SendFallbackDenyResponse(config, gjson.GetBytes(body, "stream").Bool()); sendErr != nil {
log.Errorf("failed to build deny response body: %v", sendErr)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
return
}
proxywasm.ReplaceHttpRequestBody(newBody)
@@ -125,10 +119,13 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
} else {
ctx.SetUserAttribute("safecheck_status", "request pass")
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(maskedContent) {
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
}
} else {
@@ -140,6 +137,30 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
if desensitization == "" {
proxywasm.LogInfof("safecheck_action_source=mask_fallback_to_block, reason=empty_desensitization")
log.Warnf("desensitization content is empty, falling back to block logic")
// Keep this fallback separate from RiskBlock: legacy reuses the
// original deny body in content, while structured emits an empty
// fallback guardrail object.
isStream := gjson.GetBytes(body, "stream").Bool()
var sendErr error
if !config.ProtocolOriginal && config.OpenAIDenyResponseFormat != cfg.OpenAIDenyResponseFormatStructured {
sendErr = cfg.SendDenyResponse(config, response, consumer, isStream)
} else {
sendErr = cfg.SendFallbackDenyResponse(config, isStream)
}
if sendErr != nil {
log.Errorf("failed to build deny response body: %v", sendErr)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
return
} else {
// Replace only the current chunk portion in maskedContent
chunkStart := prevContentIndex
@@ -156,28 +177,19 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
if replaceErr != nil {
log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
// Fall back to block to prevent leaking sensitive data
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
if sendErr := cfg.SendFallbackDenyResponse(config, gjson.GetBytes(body, "stream").Bool()); sendErr != nil {
log.Errorf("failed to build deny response body: %v", sendErr)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
return
}
proxywasm.ReplaceHttpRequestBody(newBody)
@@ -185,52 +197,41 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request mask")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultMask)
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
}
} else {
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultMask)
singleCall()
}
return
}
// Fall through to block logic when desensitization is empty
fallthrough
case cfg.RiskBlock:
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
if err := cfg.SendDenyResponse(config, response, consumer, gjson.GetBytes(body, "stream").Bool()); err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText)
prevContentIndex = contentIndex
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(maskedContent) {
@@ -245,6 +246,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
@@ -253,6 +255,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
if imageIndex < len(images) {
singleCallForImage()
} else {
@@ -264,6 +267,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
if imageIndex < len(images) {
singleCallForImage()
} else {
@@ -276,7 +280,10 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
if imageIndex >= len(images) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if imageIndex >= len(images) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
} else {
singleCallForImage()
@@ -284,36 +291,25 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
return
}
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
if err := cfg.SendDenyResponse(config, response, consumer, gjson.GetBytes(body, "stream").Bool()); err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCallForImage = func() {
currentImageSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage)
img := images[imageIndex]
imgUrl := ""
imgBase64 := ""
@@ -326,6 +322,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}

View File

@@ -2,7 +2,6 @@ package text
import (
"encoding/json"
"fmt"
"net/http"
"time"
@@ -27,10 +26,12 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
}
contentIndex := 0
sessionID, _ := utils.GenerateHexID(20)
currentSubmissionIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -38,6 +39,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at request phase")
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
@@ -46,44 +48,36 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
if contentIndex >= len(content) {
cfg.WriteGuardrailLog(ctx)
proxywasm.ResumeHttpRequest()
} else {
singleCall()
}
return
}
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
if err := cfg.SendDenyResponse(config, response, consumer, gjson.GetBytes(body, "stream").Bool()); err != nil {
log.Errorf("failed to build deny response body: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime)
proxywasm.ResumeHttpRequest()
return
}
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
if len(response.Data.Result) > 0 {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
cfg.WriteGuardrailLog(ctx)
}
singleCall = func() {
currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText)
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
@@ -97,6 +91,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime)
proxywasm.ResumeHttpRequest()
}
}
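
Reviewer note: both request handlers above now delegate to `cfg.SendDenyResponse` and `cfg.SendFallbackDenyResponse`, whose bodies are not shown in this part of the diff. The sketch below is only an illustration of the two non-stream OpenAI shapes the tests assert on, assuming `legacy` puts the JSON-stringified deny body in `message.content` and `structured` keeps plain text there with the object under `choices[0].x_higress_guardrail`. The names `denyBody`, `blockedDetail`, `buildDenyChoice`, `contentOf` and `decodeLegacyContent` are hypothetical stand-ins, not the plugin's API.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// blockedDetail and denyBody are hypothetical stand-ins for the plugin's
// cfg.DenyResponseBody; the JSON keys follow the README example in this commit.
type blockedDetail struct {
	Type  string `json:"type"`
	Level string `json:"level"`
}

type denyBody struct {
	Code           int             `json:"code"`
	DenyMessage    string          `json:"denyMessage"`
	BlockedDetails []blockedDetail `json:"blockedDetails"`
}

// buildDenyChoice renders choices[0] of a non-stream deny response.
// structured=true: content is the plain deny text, details go to x_higress_guardrail.
// structured=false: content is the JSON-stringified deny body (assumed legacy shape).
func buildDenyChoice(deny denyBody, structured bool) (map[string]interface{}, error) {
	message := map[string]string{"role": "assistant", "content": deny.DenyMessage}
	choice := map[string]interface{}{
		"index":         0,
		"message":       message,
		"finish_reason": "stop",
	}
	if structured {
		choice["x_higress_guardrail"] = deny
		return choice, nil
	}
	raw, err := json.Marshal(deny)
	if err != nil {
		return nil, err
	}
	message["content"] = string(raw)
	return choice, nil
}

// contentOf returns message.content of a choice built by buildDenyChoice.
func contentOf(choice map[string]interface{}) string {
	return choice["message"].(map[string]string)["content"]
}

// decodeLegacyContent parses a legacy content string back into a denyBody.
func decodeLegacyContent(content string) (denyBody, error) {
	var d denyBody
	err := json.Unmarshal([]byte(content), &d)
	return d, err
}

func main() {
	deny := denyBody{
		Code:           200,
		DenyMessage:    "request blocked",
		BlockedDetails: []blockedDetail{{Type: "contentModeration", Level: "high"}},
	}
	for _, structured := range []bool{false, true} {
		choice, err := buildDenyChoice(deny, structured)
		if err != nil {
			panic(err)
		}
		out, _ := json.Marshal(choice)
		fmt.Println(string(out))
	}
}
```

A client that has to handle both formats could check for `x_higress_guardrail` first and fall back to decoding `message.content` only when that field is absent.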

View File

@@ -312,6 +312,29 @@ var protocolOriginalConfig = func() json.RawMessage {
return data
}()
func withConfigOverrides(base json.RawMessage, overrides map[string]interface{}) json.RawMessage {
var config map[string]interface{}
_ = json.Unmarshal(base, &config)
for k, v := range overrides {
config[k] = v
}
data, _ := json.Marshal(config)
return data
}
func withStructuredFormat(base json.RawMessage) json.RawMessage {
return withConfigOverrides(base, map[string]interface{}{
"openAIDenyResponseFormat": string(cfg.OpenAIDenyResponseFormatStructured),
})
}
func mustDecodeLegacyDenyContent(t *testing.T, content string) cfg.DenyResponseBody {
t.Helper()
var denyBody cfg.DenyResponseBody
require.NoError(t, json.Unmarshal([]byte(content), &denyBody))
return denyBody
}
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test basic config parsing
@@ -335,6 +358,69 @@ func TestParseConfig(t *testing.T) {
require.Equal(t, 1000, securityConfig.BufferLimit)
require.Equal(t, cfg.DefaultResponseFallbackJsonPaths(), securityConfig.ResponseContentFallbackJsonPaths)
require.Equal(t, cfg.DefaultStreamingResponseFallbackJsonPaths(), securityConfig.ResponseStreamContentFallbackJsonPaths)
require.Equal(t, cfg.OpenAIDenyResponseFormatLegacy, securityConfig.OpenAIDenyResponseFormat)
})
t.Run("openai deny response format explicit legacy", func(t *testing.T) {
host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{
"openAIDenyResponseFormat": "legacy",
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, cfg.OpenAIDenyResponseFormatLegacy, securityConfig.OpenAIDenyResponseFormat)
})
t.Run("openai deny response format explicit structured", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(requestOnlyConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, cfg.OpenAIDenyResponseFormatStructured, securityConfig.OpenAIDenyResponseFormat)
})
t.Run("invalid openai deny response format", func(t *testing.T) {
host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{
"openAIDenyResponseFormat": "json",
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
t.Run("empty openai deny response format is invalid", func(t *testing.T) {
host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{
"openAIDenyResponseFormat": "",
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
t.Run("consumer risk level cannot override openai deny response format", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"action": "MultiModalGuard",
"contentModerationLevelBar": "high",
"consumerRiskLevel": []map[string]interface{}{
{
"name": "consumer-a",
"matchType": "exact",
"openAIDenyResponseFormat": "structured",
},
},
})
require.NoError(t, err)
var securityConfig cfg.AISecurityConfig
parseErr := securityConfig.Parse(gjson.ParseBytes(configJSON))
require.EqualError(t, parseErr, cfg.OpenAIDenyResponseFormatConsumerScopeError)
})
// Test the request-only check config
@@ -509,7 +595,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test the case where request checking is enabled
t.Run("request checking enabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
@@ -547,7 +633,7 @@ func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test that the request body security check passes
t.Run("request body security check pass", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
@@ -2655,7 +2741,7 @@ func TestTextModerationPlusResponseDeny(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// TextModerationPlus response deny → exercises text_moderation_plus/text (via common/text) BuildDenyResponseBody response path
t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
@@ -2684,22 +2770,26 @@ func TestTextModerationPlusResponseDeny(t *testing.T) {
require.NotNil(t, local, "expected SendHttpResponse for response deny")
require.Contains(t, string(local.Data), "blockedDetails")
// Verify OpenAI completion shape wrapper
// Verify OpenAI completion shape wrapper in structured mode:
// message.content carries only the human-readable deny text and the
// structured deny payload moves to choices[0].x_higress_guardrail.
type openAIChatCompletion struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
Guardrail cfg.DenyResponseBody `json:"x_higress_guardrail"`
} `json:"choices"`
}
var outer openAIChatCompletion
require.NoError(t, json.Unmarshal(local.Data, &outer))
require.Len(t, outer.Choices, 1)
var deny cfg.DenyResponseBody
require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny))
require.Equal(t, 200, deny.Code)
require.NotEmpty(t, deny.BlockedDetails)
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content)
require.Equal(t, 200, outer.Choices[0].Guardrail.Code)
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
require.NotEmpty(t, outer.Choices[0].Guardrail.BlockedDetails)
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
})
})
}
@@ -3289,3 +3379,835 @@ func TestMultiModalGuardMaskStreamDeny(t *testing.T) {
})
})
}
func TestOpenAIDenyLegacyDefaultNonStream(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("default legacy non-stream response keeps deny body in content", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-default", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.Equal(t, "stop", gjson.GetBytes(local.Data, "choices.0.finish_reason").String())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
denyBody := mustDecodeLegacyDenyContent(t, content)
require.Equal(t, 200, denyBody.Code)
require.NotEmpty(t, denyBody.BlockedDetails)
})
})
}
func TestOpenAIDenyLegacyDefaultStream(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("default legacy stream response keeps deny body in first content frame", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-stream", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
raw := string(local.Data)
require.True(t, strings.HasSuffix(strings.TrimSpace(raw), "data: [DONE]"))
parts := strings.Split(raw, "\n\n")
require.GreaterOrEqual(t, len(parts), 3)
firstFrame := strings.TrimSpace(strings.TrimPrefix(parts[0], "data:"))
endFrame := strings.TrimSpace(strings.TrimPrefix(parts[1], "data:"))
firstContent := gjson.Get(firstFrame, "choices.0.delta.content").String()
denyBody := mustDecodeLegacyDenyContent(t, firstContent)
require.Equal(t, 200, denyBody.Code)
require.NotEmpty(t, denyBody.BlockedDetails)
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress_guardrail").Exists())
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress").Exists())
require.False(t, gjson.Get(endFrame, "choices.0.x_higress_guardrail").Exists())
require.False(t, gjson.Get(endFrame, "choices.0.x_higress").Exists())
require.Equal(t, "stop", gjson.Get(endFrame, "choices.0.finish_reason").String())
})
})
}
func TestOpenAIDenyLegacyDenyCodeKeepsResponseCodeInContent(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("legacy content code remains safecheck response code when denyCode differs", func(t *testing.T) {
host, status := test.NewTestHost(withConfigOverrides(multiModalGuardTextConfig, map[string]interface{}{
"denyCode": 451,
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-451", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.Equal(t, uint32(451), local.StatusCode)
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
denyBody := mustDecodeLegacyDenyContent(t, content)
require.Equal(t, 200, denyBody.Code)
})
})
}
func TestMaskEmptyDesensitizationOpenAIFormats(t *testing.T) {
securityResponse := `{
"Code": 200, "Message": "Success", "RequestId": "req-mask-empty-openai",
"Data": {
"RiskLevel": "none",
"Detail": [{
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
"Result": [{"Label": "phone", "Confidence": 99.0,
"Ext": {"Desensitization": ""}}]
}]
}
}`
test.RunTest(t, func(t *testing.T) {
t.Run("legacy empty desensitization uses json-stringified deny body", func(t *testing.T) {
host, status := test.NewTestHost(maskConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
denyBody := mustDecodeLegacyDenyContent(t, content)
require.Equal(t, 200, denyBody.Code)
require.Empty(t, denyBody.BlockedDetails)
})
t.Run("structured empty desensitization uses fallback guardrail", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(maskConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String())
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
require.Equal(t, int64(0), gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.blockedDetails.#").Int())
})
})
}
func TestMaskReplaceJsonFieldFailureOpenAIFormats(t *testing.T) {
securityResponse := `{
"Code": 200, "Message": "Success", "RequestId": "req-mask-replace-failure",
"Data": {
"RiskLevel": "none",
"Detail": [{
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
"Result": [{"Label": "phone", "Confidence": 99.0,
"Ext": {"Desensitization": "masked"}}]
}]
}
}`
test.RunTest(t, func(t *testing.T) {
t.Run("legacy replace failure keeps pure deny message content", func(t *testing.T) {
host, status := test.NewTestHost(withConfigOverrides(maskConfig, map[string]interface{}{
"requestContentJsonPath": "@this.messages.0.content",
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
})
t.Run("structured replace failure emits fallback guardrail", func(t *testing.T) {
host, status := test.NewTestHost(withConfigOverrides(withStructuredFormat(maskConfig), map[string]interface{}{
"requestContentJsonPath": "@this.messages.0.content",
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String())
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject())
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
require.Equal(t, int64(0), gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.blockedDetails.#").Int())
})
})
}
// =============================================================================
// x_higress_guardrail extension field pipeline and template tests
// =============================================================================
// openAIChoiceWithGuardrail is the minimal OpenAI choice shape used by
// x_higress_guardrail assertions. Guardrail is unmarshaled directly into the strongly
// typed cfg.DenyResponseBody so tests assert the documented contract — code
// is int, denyMessage is string, blockedDetails is a slice — instead of
// silently tolerating shape drift through map[string]interface{}.
type openAIChoiceWithGuardrail struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
Delta struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
Guardrail cfg.DenyResponseBody `json:"x_higress_guardrail"`
}
// openAIBodyWithGuardrail also carries top-level extension fields so tests can
// assert the design contract that x_higress_guardrail lives ONLY inside choices[0] —
// a regression where it leaks to the body root would deserialize into this
// field and the require.Empty check at the call site would fail.
type openAIBodyWithGuardrail struct {
Choices []openAIChoiceWithGuardrail `json:"choices"`
Guardrail *cfg.DenyResponseBody `json:"x_higress_guardrail,omitempty"`
XHigress *cfg.DenyResponseBody `json:"x_higress,omitempty"`
}
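
The typed-decoding idea described in the comments above can be sketched with the standard library alone. This is a minimal illustration, not the plugin's code: `denyBody`, `choice`, `body`, and `guardrailPlacement` are hypothetical stand-ins (the real tests use `cfg.DenyResponseBody`), and only the three fields named in the comment are modeled.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// denyBody is a hypothetical stand-in for cfg.DenyResponseBody; it models only
// the documented fields: code (int), denyMessage (string), blockedDetails (slice).
type denyBody struct {
	Code           int               `json:"code"`
	DenyMessage    string            `json:"denyMessage"`
	BlockedDetails []json.RawMessage `json:"blockedDetails"`
}

// choice carries the extension field where the contract says it belongs.
type choice struct {
	Guardrail *denyBody `json:"x_higress_guardrail,omitempty"`
}

// body also declares the field at the root so a leak becomes observable.
type body struct {
	Choices   []choice  `json:"choices"`
	Guardrail *denyBody `json:"x_higress_guardrail,omitempty"`
}

// guardrailPlacement reports whether the extension field is present under
// choices[0] and whether it is present at the body root. A type mismatch
// (for example code sent as a string) surfaces as a decode error.
func guardrailPlacement(raw []byte) (nested, root bool, err error) {
	var b body
	if err = json.Unmarshal(raw, &b); err != nil {
		return false, false, err
	}
	nested = len(b.Choices) > 0 && b.Choices[0].Guardrail != nil
	root = b.Guardrail != nil
	return nested, root, nil
}

func main() {
	// Hand-written sample payloads, not captured plugin output.
	good := []byte(`{"choices":[{"x_higress_guardrail":{"code":200,"denyMessage":"no","blockedDetails":[]}}]}`)
	leaked := []byte(`{"choices":[{}],"x_higress_guardrail":{"code":200}}`)
	for _, raw := range [][]byte{good, leaked} {
		nested, root, err := guardrailPlacement(raw)
		fmt.Println(nested, root, err)
	}
}
```

Decoding into pointer fields keeps "absent" distinguishable from "present but empty", which is what makes the root-leak check meaningful.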
// TestRequestDenyGuardrailNonStream verifies that a request-phase non-streaming deny renders
// content as plain text and embeds x_higress_guardrail as a JSON object.
func TestRequestDenyGuardrailNonStream(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("non-stream request deny carries guardrail object", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-x-higress", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
// x_higress_guardrail must be a JSON object, not a string
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(),
"x_higress_guardrail should be an embedded JSON object, not a string literal")
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(),
"structured deny must not emit the old choices[0].x_higress field")
var outer openAIBodyWithGuardrail
require.NoError(t, json.Unmarshal(local.Data, &outer))
require.Len(t, outer.Choices, 1)
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content,
"content should carry only the human-readable deny text")
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
// A-R1F4 contract: x_higress_guardrail.code carries the gateway-emitted HTTP deny
// status (config.DenyCode, default 200), NOT the upstream security
// service's Response.Code. multiModalGuardTextConfig leaves denyCode
// unset, so this resolves to cfg.DefaultDenyCode.
require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code,
"x_higress_guardrail.code must equal the gateway-emitted HTTP deny status")
require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails)
// Design contract: x_higress_guardrail lives ONLY nested under choices[0].
require.Nil(t, outer.Guardrail,
"x_higress_guardrail must not leak to body root; only choices[0].x_higress_guardrail is valid")
require.Nil(t, outer.XHigress, "old x_higress must not leak to body root")
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root; only choices[0].x_higress_guardrail is valid")
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
"old x_higress must not leak to body root")
})
})
}
// TestRequestDenyGuardrailStreamFrames verifies that a request-phase streaming deny only
// embeds x_higress_guardrail in the final chunk and the first chunk carries plain text.
func TestRequestDenyGuardrailStreamFrames(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("stream request deny only attaches guardrail in last chunk", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-x-higress", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
raw := string(local.Data)
require.True(t, strings.HasSuffix(strings.TrimSpace(raw), "data: [DONE]"))
parts := strings.Split(raw, "\n\n")
// Template format: data:<chunk>\n\ndata:<end>\n\ndata: [DONE]
require.GreaterOrEqual(t, len(parts), 3, "expected at least chunk + end + DONE")
firstFrame := strings.TrimPrefix(parts[0], "data:")
endFrame := strings.TrimPrefix(parts[1], "data:")
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress_guardrail").Exists(),
"first chunk should not carry x_higress_guardrail")
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress").Exists(),
"first chunk should not carry old x_higress")
require.Equal(t, cfg.DefaultDenyMessage,
gjson.Get(firstFrame, "choices.0.delta.content").String(),
"first chunk delta.content should be plain text")
require.True(t, gjson.Get(endFrame, "choices.0.x_higress_guardrail").IsObject(),
"final chunk should carry x_higress_guardrail as object")
require.False(t, gjson.Get(endFrame, "choices.0.x_higress").Exists(),
"final chunk should not carry old x_higress")
// Deny stream's terminator carries `stop` for wire-level compatibility
// with downstream consumers (LangChain / LiteLLM / SDKs / BI) that key
// off `stop` as a valid completion. The moderation-event signal lives
// in choices[0].x_higress_guardrail (code / blockedDetails) instead.
require.Equal(t, "stop", gjson.Get(endFrame, "choices.0.finish_reason").String())
require.False(t, gjson.Get(endFrame, "choices.0.delta.content").Exists(),
"final chunk delta should be empty")
// Design contract: x_higress_guardrail lives ONLY nested under choices[0].
require.False(t, gjson.Get(endFrame, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root of the end frame")
require.False(t, gjson.Get(firstFrame, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root of the first frame")
require.False(t, gjson.Get(endFrame, "x_higress").Exists(),
"old x_higress must not leak to body root of the end frame")
require.False(t, gjson.Get(firstFrame, "x_higress").Exists(),
"old x_higress must not leak to body root of the first frame")
require.Contains(t, parts[2], "[DONE]")
})
})
}
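
The frame layout asserted above (content chunk, end chunk, then `[DONE]`) can be checked with a small standard-library sketch. The helper names `splitSSEFrames` and `hasGuardrail` are hypothetical, and the sample stream is hand-written to match the template format rather than taken from the plugin.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// splitSSEFrames splits a raw SSE body on blank lines and strips the "data:"
// prefix from each event, returning the non-empty payloads in order.
func splitSSEFrames(raw string) []string {
	var frames []string
	for _, part := range strings.Split(strings.TrimSpace(raw), "\n\n") {
		payload := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(part), "data:"))
		if payload != "" {
			frames = append(frames, payload)
		}
	}
	return frames
}

// hasGuardrail reports whether choices[0].x_higress_guardrail is a JSON object
// in the given frame. A string value or a non-JSON frame yields false.
func hasGuardrail(frame string) bool {
	var chunk struct {
		Choices []struct {
			Guardrail map[string]json.RawMessage `json:"x_higress_guardrail"`
		} `json:"choices"`
	}
	if err := json.Unmarshal([]byte(frame), &chunk); err != nil {
		return false
	}
	return len(chunk.Choices) > 0 && chunk.Choices[0].Guardrail != nil
}

func main() {
	// Illustrative stream only: first chunk with text, end chunk with the
	// extension object, then the terminator.
	raw := "data:{\"choices\":[{\"delta\":{\"content\":\"denied\"}}]}\n\n" +
		"data:{\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\",\"x_higress_guardrail\":{\"code\":200}}]}\n\n" +
		"data: [DONE]\n\n"
	for i, frame := range splitSSEFrames(raw) {
		fmt.Println(i, hasGuardrail(frame))
	}
}
```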
// TestRequestDenyDefaultDenyMessage verifies that without a configured
// denyMessage, both content and x_higress_guardrail.denyMessage fall back to the default
// text.
func TestRequestDenyDefaultDenyMessage(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("default deny message used when not configured", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-default", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
denyMessage := gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.denyMessage").String()
require.Equal(t, cfg.DefaultDenyMessage, content)
require.Equal(t, cfg.DefaultDenyMessage, denyMessage)
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
})
})
}
// TestProtocolOriginalDenyShapePreserved guards the regression that the
// protocol: "original" normal deny path keeps the bare DenyResponseBody shape
// without OpenAI wrapping or x_higress_guardrail.
func TestProtocolOriginalDenyShapePreserved(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
for _, tc := range []struct {
name string
config json.RawMessage
}{
{name: "default format", config: protocolOriginalConfig},
{name: "legacy format", config: withConfigOverrides(protocolOriginalConfig, map[string]interface{}{"openAIDenyResponseFormat": "legacy"})},
{name: "structured format", config: withStructuredFormat(protocolOriginalConfig)},
} {
tc := tc
t.Run("original protocol deny stays as bare DenyResponseBody "+tc.name, func(t *testing.T) {
host, status := test.NewTestHost(tc.config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-orig-shape", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
require.False(t, gjson.GetBytes(local.Data, "choices").Exists(),
"original protocol should not OpenAI-wrap the body")
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
"original protocol should not introduce x_higress_guardrail")
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
"original protocol should not introduce old x_higress")
require.True(t, gjson.GetBytes(local.Data, "blockedDetails").Exists())
require.Equal(t, int64(200), gjson.GetBytes(local.Data, "code").Int())
})
}
})
}
// TestMaskEmptyDesensitizationOriginalShape guards the A-R1F2 alignment:
// when riskAction=mask and the upstream returns empty Desensitization under
// protocol: "original", the response body must be the bare JSON string literal
// produced by wrapper.MarshalStr(ResolveDenyMessage(config)) — i.e.
// `"<deny message>"` — mirroring the ReplaceJsonFieldTextContent failure path
// at lvwang/multi_modal_guard/text/openai.go:102 / :159.
//
// Before A-R1F2 this branch fell through to RiskBlock and called
// BuildDenyResponseBody, returning a structured {code, blockedDetails, ...}
// object instead. The fallthrough was inconsistent with design Section 5 and
// has been replaced with the self-handled MarshalStr path; this regression
// test locks in the new contract so the divergence cannot reappear silently.
func TestMaskEmptyDesensitizationOriginalShape(t *testing.T) {
maskOriginalConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": false,
"action": "MultiModalGuard",
"riskAction": "mask",
"protocol": "original",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
})
return data
}()
test.RunTest(t, func(t *testing.T) {
t.Run("mask empty desensitization under original emits MarshalStr literal", func(t *testing.T) {
host, status := test.NewTestHost(maskOriginalConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
// riskAction is mask but the desensitized content is empty, so the A-R1F2 self-handled branch runs.
securityResponse := `{
"Code": 200, "Message": "Success", "RequestId": "req-mask-empty-orig",
"Data": {
"RiskLevel": "none",
"Detail": [{
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
"Result": [{"Label": "phone", "Confidence": 99.0,
"Ext": {"Desensitization": ""}}]
}]
}
}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local)
// Not OpenAI-wrapped.
require.False(t, gjson.GetBytes(local.Data, "choices").Exists(),
"mask empty-desensitization under original must not be OpenAI-wrapped")
require.False(t, gjson.GetBytes(local.Data, "object").Exists(),
"mask empty-desensitization under original must not include OpenAI 'object' field")
// No longer the DenyResponseBody {code, blockedDetails} structure;
// the body is now the wrapper.MarshalStr output: a bare JSON string literal.
require.False(t, gjson.GetBytes(local.Data, "code").Exists(),
"A-R1F2: body should now be a JSON string literal, not a {code, blockedDetails} object")
require.False(t, gjson.GetBytes(local.Data, "blockedDetails").Exists(),
"A-R1F2: body should now be a JSON string literal, not a {code, blockedDetails} object")
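// Actual shape: the wrapper.MarshalStr output. In Higress this wrapper returns
// the string with its outer double quotes already stripped (see the C-R1F9 note),
// so the body is the raw deny-text bytes (it cannot be json.Unmarshal-ed back into a string).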
require.Equal(t, cfg.DefaultDenyMessage, string(local.Data),
"body should equal raw deny message (wrapper.MarshalStr strips outer quotes — C-R1F9)")
})
})
}
// TestResponseStreamingDenyGuardrail drives HandleTextGenerationStreamingResponseBody
// (lvwang/common/text/openai.go) — the only response-side structured stream
// writer — and asserts the injected SSE carries:
// - first chunk: human-readable content, no x_higress_guardrail
// - end chunk: x_higress_guardrail as a JSON object with code/denyMessage/blockedDetails
// - terminator: "data: [DONE]"
func TestResponseStreamingDenyGuardrail(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("response streaming deny injects guardrail only in last frame", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Skip request-phase check by using a non-deny request body.
body := `{"messages": [{"role": "user", "content": "hello"}]}`
host.CallOnHttpRequestBody([]byte(body))
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{"Code": 200, "Message": "Success", "RequestId": "req-stream-resp-pass", "Data": {"RiskLevel": "none"}}`))
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
require.Equal(t, types.ActionContinue, action)
// Single chunk + end_of_stream=true triggers the security check.
chunk := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"bad response\"}}]}\n\n")
host.CallOnHttpStreamingResponseBody(chunk, true)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-resp-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
injected := host.GetResponseBody()
require.NotEmpty(t, injected, "expected InjectEncodedDataToFilterChain to deliver SSE deny payload")
injectedStr := string(injected)
require.True(t, strings.HasSuffix(strings.TrimSpace(injectedStr), "data: [DONE]"),
"expected SSE stream to end with [DONE], got: %s", injectedStr)
// Split into SSE events; the "data:" prefix is stripped from each event below.
events := strings.Split(strings.TrimSpace(injectedStr), "\n\n")
require.GreaterOrEqual(t, len(events), 2, "expected at least first chunk + end chunk")
// First event: content present, no x_higress_guardrail.
firstPayload := strings.TrimPrefix(events[0], "data:")
firstPayload = strings.TrimSpace(firstPayload)
require.Equal(t, cfg.DefaultDenyMessage, gjson.Get(firstPayload, "choices.0.delta.content").String())
require.False(t, gjson.Get(firstPayload, "choices.0.x_higress_guardrail").Exists(),
"first chunk must NOT carry x_higress_guardrail")
require.False(t, gjson.Get(firstPayload, "choices.0.x_higress").Exists(),
"first chunk must NOT carry old x_higress")
// Second event: x_higress_guardrail as JSON object on choices[0].
secondPayload := strings.TrimPrefix(events[1], "data:")
secondPayload = strings.TrimSpace(secondPayload)
guardrail := gjson.Get(secondPayload, "choices.0.x_higress_guardrail")
require.True(t, guardrail.Exists(), "end chunk must carry x_higress_guardrail")
require.True(t, guardrail.IsObject(), "x_higress_guardrail must be a JSON object, not a string")
require.Equal(t, cfg.DefaultDenyMessage, guardrail.Get("denyMessage").String())
require.True(t, guardrail.Get("blockedDetails").Exists())
require.False(t, gjson.Get(secondPayload, "choices.0.x_higress").Exists(),
"end chunk must not carry old x_higress")
// Streaming deny terminator carries `stop` for wire-level compatibility;
// moderation signal lives in choices[0].x_higress_guardrail.
require.Equal(t, "stop", gjson.Get(secondPayload, "choices.0.finish_reason").String())
// Design contract: x_higress_guardrail lives ONLY nested under choices[0].
require.False(t, gjson.Get(secondPayload, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root of the end chunk")
require.False(t, gjson.Get(firstPayload, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root of the first chunk")
require.False(t, gjson.Get(secondPayload, "x_higress").Exists(),
"old x_higress must not leak to body root of the end chunk")
require.False(t, gjson.Get(firstPayload, "x_higress").Exists(),
"old x_higress must not leak to body root of the first chunk")
})
})
}
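// A-R2-16(a): the multi_modal_guard image-moderation deny path is not covered by the
// guardrail tests above, while the R1-F6 fix (the BuildOpenAIFallbackDenyResponseBody
// err is no longer swallowed) relies on callbackForImage and singleCallForImage calling
// BuildOpenAIDenyResponseBody. This test sends a pure image_url request body that hits
// the singleCallForImage path directly and asserts that an image-moderation deny also
// embeds x_higress_guardrail in choices[0], symmetric with the text deny shape.
//
// File: lvwang/multi_modal_guard/text/openai.go:299-369 (callbackForImage / OpenAI wrapping)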
func TestImageDenyGuardrailShape(t *testing.T) {
imageCheckConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkRequestImage": true,
"action": "MultiModalGuard",
"apiType": "text_generation",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
"bufferLimit": 1000,
})
return data
}()
test.RunTest(t, func(t *testing.T) {
t.Run("multi_modal_guard image deny embeds guardrail on choices[0]", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(imageCheckConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Pure image_url content: content == "", so parseContent skips text moderation and goes
// straight to singleCallForImage, letting callbackForImage render the deny.
body := `{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"https://example.com/bad.jpg"}}]}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for image deny")
// x_higress_guardrail must be embedded as an object inside choices[0].
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(),
"image deny should embed x_higress_guardrail as JSON object on choices[0]")
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(),
"image deny should not emit old x_higress on choices[0]")
var outer openAIBodyWithGuardrail
require.NoError(t, json.Unmarshal(local.Data, &outer))
require.Len(t, outer.Choices, 1)
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content,
"content carries human-readable deny text")
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code,
"x_higress_guardrail.code = gateway HTTP deny status (A-R1F4 contract)")
require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails)
// Must not leak to the body root.
require.Nil(t, outer.Guardrail, "x_higress_guardrail must not leak to body root")
require.Nil(t, outer.XHigress, "old x_higress must not leak to body root")
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root")
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
"old x_higress must not leak to body root")
})
})
}
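// A-R2-16(b): the text_moderation_plus request-phase deny path is not covered by the
// guardrail tests above. The existing TestTextModerationPlusResponseDeny covers the
// response phase; this test adds the symmetric request-phase case to ensure that, under
// the OpenAI protocol, the structured deny body places x_higress_guardrail in choices[0].
//
// File: lvwang/text_moderation_plus/text/openai.go:56-92 (request-phase deny rendering)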
func TestTextModerationPlusRequestDenyGuardrailShape(t *testing.T) {
tmpRequestConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"action": "TextModerationPlus",
"contentModerationLevelBar": "high",
"timeout": 2000,
})
return data
}()
test.RunTest(t, func(t *testing.T) {
t.Run("text_moderation_plus request deny embeds guardrail on choices[0]", func(t *testing.T) {
host, status := test.NewTestHost(withStructuredFormat(tmpRequestConfig))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for text_moderation_plus request deny")
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(),
"text_moderation_plus request deny should embed x_higress_guardrail as JSON object on choices[0]")
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(),
"text_moderation_plus request deny should not emit old x_higress")
var outer openAIBodyWithGuardrail
require.NoError(t, json.Unmarshal(local.Data, &outer))
require.Len(t, outer.Choices, 1)
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content)
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code,
"x_higress_guardrail.code = gateway HTTP deny status (A-R1F4 contract)")
require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails)
require.Nil(t, outer.Guardrail, "x_higress_guardrail must not leak to body root")
require.Nil(t, outer.XHigress, "old x_higress must not leak to body root")
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
"x_higress_guardrail must not leak to body root")
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
"old x_higress must not leak to body root")
})
})
}


@@ -58,7 +58,7 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string
fieldValue := gjson.GetBytes(body, resolved)
if !fieldValue.IsArray() {
// Simple string content — replace directly
-return sjson.SetBytes(body, resolved, newContent)
+return setJsonTextContent(body, resolved, newContent)
}
// Array content (multimodal): replace text items, preserve others
result := body
@@ -86,7 +86,7 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string
// If there's only one text item, put all desensitized content there
if len(textEntries) == 1 {
itemPath := fmt.Sprintf("%s.%d.text", resolved, textEntries[0].index)
-return sjson.SetBytes(result, itemPath, newContent)
+return setJsonTextContent(result, itemPath, newContent)
}
// Multiple text items: split desensitized content proportionally by original lengths
for j, entry := range textEntries {
@@ -117,7 +117,7 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string
remaining = remaining[byteOffset:]
}
itemPath := fmt.Sprintf("%s.%d.text", resolved, entry.index)
-result, err = sjson.SetBytes(result, itemPath, replacement)
+result, err = setJsonTextContent(result, itemPath, replacement)
if err != nil {
return nil, err
}
@@ -125,6 +125,18 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string
return result, nil
}
func setJsonTextContent(body []byte, jsonPath string, newContent string) ([]byte, error) {
current := gjson.GetBytes(body, jsonPath)
result, err := sjson.SetBytes(body, jsonPath, newContent)
if err != nil {
return nil, err
}
if current.Exists() && current.String() != newContent && bytes.Equal(result, body) {
return nil, fmt.Errorf("failed to replace json path %q", jsonPath)
}
return result, nil
}
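
The guard in `setJsonTextContent` treats "the value exists, differs from the new content, yet the bytes did not change" as a failed write. The same pattern is sketched below without gjson or sjson, as an assumption-labeled illustration: `setter`, `setOrFail`, `replaceSecret`, `ignoreWrite`, and `errNoOp` are all hypothetical names, and the setters are toy stand-ins for a real JSON write.

```go
package main

import (
	"bytes"
	"errors"
	"fmt"
)

// setter is a hypothetical stand-in for a JSON write call.
type setter func(body []byte, value string) ([]byte, error)

// errNoOp marks a write that returned no error but changed nothing.
var errNoOp = errors.New("write left the body unchanged")

// setOrFail runs set and converts a silent no-op into an error. The write is
// only considered a no-op when the target exists and its current value differs
// from the requested one; writing an identical value legitimately changes nothing.
func setOrFail(body []byte, exists bool, current, value string, set setter) ([]byte, error) {
	result, err := set(body, value)
	if err != nil {
		return nil, err
	}
	if exists && current != value && bytes.Equal(result, body) {
		return nil, errNoOp
	}
	return result, nil
}

// replaceSecret is a toy setter that actually rewrites the value.
func replaceSecret(body []byte, value string) ([]byte, error) {
	return bytes.Replace(body, []byte("secret"), []byte(value), 1), nil
}

// ignoreWrite is a toy setter that silently returns the input unchanged.
func ignoreWrite(body []byte, value string) ([]byte, error) {
	return body, nil
}

func main() {
	body := []byte(`{"content":"secret"}`)
	out, err := setOrFail(body, true, "secret", "masked", replaceSecret)
	fmt.Println(string(out), err)
	_, err = setOrFail(body, true, "secret", "masked", ignoreWrite)
	fmt.Println(err)
}
```

Comparing bytes after the write is a cheap check, and it matters because the caller would otherwise forward unmasked content while believing the replacement succeeded.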
// resolveJsonPath converts gjson modifier paths (e.g. "messages.@reverse.0.content")
// into concrete index paths (e.g. "messages.2.content") that sjson can handle.
func resolveJsonPath(body []byte, jsonPath string) string {


@@ -263,6 +263,14 @@ func TestResolveJsonPathEdgeCases(t *testing.T) {
})
}
func TestReplaceJsonFieldTextContentReportsReadableNoOp(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"敏感内容"}]}`)
result, err := ReplaceJsonFieldTextContent(body, "@this.messages.0.content", "masked")
if err == nil {
t.Fatalf("expected error for readable path that sjson leaves unchanged, got nil with %s", string(result))
}
}
// TestReplaceJsonFieldContent covers the simple ReplaceJsonFieldContent function
func TestReplaceJsonFieldContent(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"original"}]}`)