From c21a38e783b08ececea6fe32dc9300a9a990ebfb Mon Sep 17 00:00:00 2001 From: JianweiWang Date: Fri, 29 May 2026 10:45:10 +0800 Subject: [PATCH] feat(ai-security-guard): structured x_higress deny response, error-path metrics, and AI logging (#3894) Co-authored-by: github-actions[bot] Co-authored-by: Claude Opus 4.7 Co-authored-by: rinfx --- .../extensions/ai-security-guard/README.md | 122 ++- .../extensions/ai-security-guard/README_EN.md | 96 +- .../ai-security-guard/ai_log_test.go | 454 +++++++++ .../ai-security-guard/config/ai_log.go | 126 +++ .../ai-security-guard/config/config.go | 193 +++- .../lvwang/common/text/openai.go | 77 +- .../lvwang/multi_modal_guard/image/openai.go | 60 +- .../lvwang/multi_modal_guard/image/qwen.go | 60 +- .../lvwang/multi_modal_guard/mcp/mcp.go | 50 +- .../lvwang/multi_modal_guard/text/openai.go | 141 ++- .../text_moderation_plus/text/openai.go | 33 +- .../extensions/ai-security-guard/main_test.go | 938 +++++++++++++++++- .../ai-security-guard/utils/utils.go | 18 +- .../ai-security-guard/utils/utils_test.go | 8 + 14 files changed, 2181 insertions(+), 195 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-security-guard/ai_log_test.go create mode 100644 plugins/wasm-go/extensions/ai-security-guard/config/ai_log.go diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 5d547e9f3..02fa6a25e 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -34,6 +34,7 @@ description: 阿里云内容安全检测 | `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | +| `openAIDenyResponseFormat` | string | optional | legacy | OpenAI 包装拒答的响应形态,取值为 `legacy` 或 `structured`。默认 `legacy` 保持历史兼容;配置为 `structured` 时在 `choices[0].x_higress_guardrail` 输出结构化拦截详情 | | 
`contentModerationLevelBar` | string | optional | max | 内容合规检测拦截风险等级,取值为 `max`, `high`, `medium` or `low` | | `promptAttackLevelBar` | string | optional | max | 提示词攻击检测拦截风险等级,取值为 `max`, `high`, `medium` or `low` | | `sensitiveDataLevelBar` | string | optional | S4 | 敏感内容检测拦截风险等级,取值为 `S4`, `S3`, `S2` or `S1` | @@ -47,19 +48,18 @@ description: 阿里云内容安全检测 ### 拒绝响应结构 -内容被拦截时,插件(`MultiModalGuard` action)统一返回以下结构化 JSON 对象,各协议的承载位置如下: +内容被拦截时,插件(`MultiModalGuard` action)会构造以下结构化 JSON 对象。`protocol: original`、MCP 与图像生成路径直接或间接返回该对象;OpenAI 文本生成包装路径默认保持历史兼容形态,只有配置 `openAIDenyResponseFormat: structured` 时才会把该对象嵌入到 OpenAI 响应中。 ```json { + "code": 200, + "denyMessage": "很抱歉,我无法回答您的问题", "blockedDetails": [ { - "Type": "contentModeration", - "Level": "high", - "Suggestion": "block" + "type": "contentModeration", + "level": "high" } - ], - "requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****", - "guardCode": 200 + ] } ``` @@ -67,22 +67,26 @@ description: 阿里云内容安全检测 | 字段 | 类型 | 说明 | | --- | --- | --- | -| `blockedDetails` | array | 命中拦截的维度明细;若安全服务未返回明细,则根据顶层风险信号自动合成 | -| `blockedDetails[].Type` | string | 风险类型:`contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` | -| `blockedDetails[].Level` | string | 风险等级:`high` / `medium` / `low` 等 | -| `blockedDetails[].Suggestion` | string | 安全服务建议操作,通常为 `block` | -| `requestId` | string | 安全服务的请求 ID,用于追踪 | -| `guardCode` | int | 安全服务返回的业务码(非 HTTP 状态码,成功检测时为 `200`) | +| `code` | int | 在 `text_generation`(OpenAI 包装) 与 `image_generation` 路径下为网关返回的 HTTP 状态码,取自 `denyCode`(默认 `200`);在 `protocol=original` 与 `mcp` 路径下为安全服务返回的业务码(`Response.Code`,成功检测时为 `200`) | +| `denyMessage` | string | 人类可读拦截文案。OpenAI 包装路径下始终存在,取 `denyMessage`(默认 `很抱歉,我无法回答您的问题`);`protocol=original` / `image_generation` / `mcp` 路径下取 `denyMessage`,未配置时省略该字段(`omitempty`) | +| `blockedDetails` | array | 命中拦截的维度明细;若安全服务未返回 `Detail`,则根据顶层 `RiskLevel`/`AttackLevel` 自动合成。命中维度为空时返回 `[]` | +| `blockedDetails[].type` | string | 风险类型:`contentModeration` / 
`promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` / `customLabel` | +| `blockedDetails[].level` | string | 风险等级:`high` / `medium` / `low`;敏感数据为 `S1`–`S4` | + +> 说明:当前实现的拒答 body 仅包含上述字段。不输出安全服务的 `RequestId`、单条 `Suggestion` 与原始业务码(`guardCode`);安全服务的 `RequestId` 通过 AI 日志 `safecheck_request_ids` 字段暴露(见下文 AI Log 章节)。 各协议承载位置: -- **`text_generation`(OpenAI 非流式)**:上述结构体序列化为 JSON 字符串后放入 `choices[0].message.content` -- **`text_generation`(OpenAI 流式 SSE)**:同上,放入首个 chunk 的 `delta.content` -- **`text_generation`(`protocol=original`)**:上述结构体直接作为 JSON 响应 body 返回 +- **`text_generation`(OpenAI,默认 `legacy`)**:不输出 `x_higress_guardrail` 或历史 `x_higress` 字段;`choices[0].message.content` / 首帧 `delta.content` 保持历史内容形态(RiskBlock 为 JSON 字符串,mask fallback 为拒答文案),`finish_reason` 为 `"stop"`,流式响应仍以 `data: [DONE]` 结束 +- **`text_generation`(OpenAI,`structured` 非流式)**:`choices[0].message.content` 承载可读拦截文案(即 `denyMessage`,未配置时默认为 `很抱歉,我无法回答您的问题`);上述结构体作为嵌入对象放入 `choices[0].x_higress_guardrail`(不是 JSON 字符串) +- **`text_generation`(OpenAI,`structured` 流式 SSE)**:首帧 `delta.content` 承载可读拦截文案;上述结构体仅在最后一个 chunk 中作为嵌入对象放入 `choices[0].x_higress_guardrail`,随后以 `data: [DONE]` 结束流 +- **`text_generation`(`protocol=original`)**:上述结构体直接作为 JSON 响应 body 返回(不包 OpenAI 外壳,不新增 `x_higress_guardrail`) - **`image_generation`**:上述结构体直接作为 JSON 响应 body 返回(HTTP 403) - **`mcp`(JSON-RPC)**:上述结构体序列化为 JSON 字符串后放入 `error.message` - **`mcp`(SSE)**:同上,通过 SSE 事件返回 +`openAIDenyResponseFormat` 只影响 OpenAI 包装拒答的 body 形态;拦截判断、fail-open 行为、metric 与 AI Log 字段不随该配置变化。该字段只能配置在插件全局,不能放入 `consumerRiskLevel`。 + 补充说明一下内容合规检测、提示词攻击检测、敏感内容检测三种风险的四个等级: - 对于内容合规检测、提示词攻击检测: @@ -150,6 +154,14 @@ checkRequest: true checkResponse: true ``` +### 配置 OpenAI 结构化拒答 + +默认 `openAIDenyResponseFormat: legacy` 保持历史响应形态。若需要在 OpenAI 响应中输出结构化拦截详情,可配置: + +```yaml +openAIDenyResponseFormat: structured +``` + ### 使用临时安全凭证 ```yaml @@ -247,11 +259,61 @@ ai-security-guard 插件提供了以下监控指标: - `ai_sec_request_deny`: 请求内容安全检测失败请求数 - `ai_sec_response_deny`: 
模型回答安全检测失败请求数 +#### 图像响应阶段 metric/ai_log 字段重命名(过渡期) + +历史上图像生成插件(`lvwang/multi_modal_guard/image/openai.go`、`lvwang/multi_modal_guard/image/qwen.go`)在**响应阶段**命中风险时错误地写入了请求阶段字段。本次版本修正了语义,并在 1~2 个发版周期内保留**双写过渡**: + +| 行为 | 旧值(错误,将在后续版本移除) | 新值(推荐) | +| --- | --- | --- | +| 计数器(deny) | `ai_sec_request_deny` | `ai_sec_response_deny` | +| ai_log 耗时(pass + deny) | `safecheck_request_rt` | `safecheck_response_rt` | +| ai_log 状态(deny) | `safecheck_status="reqeust deny"`(典型拼写错误,**本次直接废弃,不再写入**) | `safecheck_status="response deny"` | + +过渡期内图像响应阶段会同时写入新旧两组 `*_deny` 计数器和 `safecheck_*_rt` 字段;`safecheck_status` 只写新值。看板与告警请尽快切换到 `response_*` 字段名;当前依赖 `reqeust deny`(拼写错误版本)状态串的图像响应告警需要立即改为 `response deny`。 + ### Trace 如果开启了链路追踪,ai-security-guard 插件会在请求 span 中添加以下 attributes: - `ai_sec_risklabel`: 表示请求命中的风险类型 - `ai_sec_deny_phase`: 表示请求被检测到风险的阶段(取值为request或者response) +### AI Log +ai-security-guard 插件会将每次提交给内容安全服务的检测结果写入 AI 访问日志,用于将网关日志和阿里云内容安全请求关联起来: + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `safecheck_requests` | array | 检测提交事件数组,每个元素为 `{"requestId"?: string, "phase": string, "modality": string, "result": string}` | +| `safecheck_request_ids` | array | 当前网关请求内所有有效内容安全 `RequestId`,按提交完成顺序保留,不去重、不截断 | +| `safecheck_request_id` | string | 最新一个有效内容安全 `RequestId`,用于兼容只读取单值的日志消费方 | +| `safecheck_status` | string | 历史兼容字段,反映本次网关请求最后一次状态变更的语义(详见下方枚举) | +| `safecheck_request_rt` / `safecheck_response_rt` | int | 请求/响应阶段安全检测的耗时(毫秒) | +| `safecheck_riskLabel` / `safecheck_riskWords` | string | 命中风险时的风险标签与风险词(取自安全服务返回的第一个命中结果) | + +`safecheck_requests[].phase` 取值为 `request` 或 `response`;`modality` 取值为 `text`、`image` 或 `mcp`;`result` 表示**该次提交事件本身的处理结果**(而非网关最终对外动作),取值与含义如下: + +| `result` 取值 | 含义 | +| --- | --- | +| `pass` | 该次提交检测通过 | +| `deny` | 该次提交命中风险,网关已对外返回拒答 | +| `mask` | 该次提交命中风险且 `Action=Mask`,安全服务返回了脱敏文本并用于改写请求体 | +| `error` | 该次提交本身处理失败(HTTP 非 200、业务 Code 非 200、反序列化失败、构造拒答响应失败、调用内容安全服务失败等)。错误发生在**响应阶段流式回调**且原因是构造拒答响应失败时,网关会 fail-open(直接放行上游缓冲内容),此时 
`safecheck_status=build_fallback_pass`,对应事件 `result=error` 表示这次安全提交未完成 | + +只有安全服务响应中的 `RequestId` 是 JSON 字符串且 `strings.TrimSpace(RequestId) != ""` 时,才会写入 `requestId`、`safecheck_request_ids` 和 `safecheck_request_id`;缺失、空字符串、空白字符串或非字符串值不会写入空占位。 + +每一次提交尝试都会生成一个 `safecheck_requests` 事件,包括 HTTP 非 200、业务失败码以及调用内容安全服务失败等错误场景,错误结果会记录为 `result=error`。需要精确审计多次提交、流式分段或图片多次检测时,应优先使用 `safecheck_requests`。 + +`safecheck_status` 枚举(历史字段,按"最后一次状态变更"覆盖,存在多次提交时仅保留最后一次的语义): + +| `safecheck_status` 取值 | 含义 | +| --- | --- | +| `request pass` | 请求阶段所有提交均通过 | +| `request mask` | 请求阶段命中 mask,请求体已被脱敏文本改写 | +| `reqeust deny` | 请求阶段命中风险,网关返回拒答(注:拼写为 `reqeust`,沿用历史,保持向后兼容) | +| `request error` | 请求阶段安全提交本身失败(HTTP/反序列化/调用安全服务等),网关 fail-open 放行 | +| `response pass` | 响应阶段所有提交均通过 | +| `response deny` | 响应阶段命中风险,网关返回拒答 | +| `response error` | 响应阶段安全提交本身失败,网关 fail-open 放行 | +| `build_fallback_pass` | 响应阶段流式回调里构造拒答响应失败,网关 fail-open 直接放行上游缓冲内容 | + ## 请求示例 ```bash curl http://localhost/v1/chat/completions \ @@ -267,25 +329,39 @@ curl http://localhost/v1/chat/completions \ }' ``` -请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答: +请求内容会被发送到阿里云内容安全服务进行检测。如果请求内容检测结果为非法,且配置了 `openAIDenyResponseFormat: structured`,网关将返回形如以下的回答: ```json { "id": "chatcmpl-AAy3hK1dE4ODaegbGOMoC9VY4Sizv", "object": "chat.completion", - "created": 1677652288, - "model": "gpt-4o-mini", - "system_fingerprint": "fp_44709d6fcb", + "created": 1727078400, + "model": "from-security-guard", "choices": [ { "index": 0, "message": { "role": "assistant", - "content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。", + "content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。" }, "logprobs": null, - "finish_reason": "stop" + "finish_reason": "stop", + "x_higress_guardrail": { + "code": 200, + "denyMessage": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。", + "blockedDetails": [ + { + "type": "contentModeration", + "level": "high" + } + ] + } } - ] + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, +
"total_tokens": 0 + } } ``` diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md index 08fa672b8..0127c55a0 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -34,6 +34,7 @@ Plugin Priority: `300` | `denyCode` | int | optional | 200 | Response status code when the specified content is illegal | | `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal | | `protocol` | string | optional | openai | protocol format, `openai` or `original` | +| `openAIDenyResponseFormat` | string | optional | legacy | OpenAI-wrapped deny response format, `legacy` or `structured`. The default `legacy` preserves historical compatibility; `structured` embeds blocking details at `choices[0].x_higress_guardrail` | | `contentModerationLevelBar` | string | optional | max | contentModeration risk level threshold, `max`, `high`, `medium` or `low` | | `promptAttackLevelBar` | string | optional | max | promptAttack risk level threshold, `max`, `high`, `medium` or `low` | | `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` | @@ -71,19 +72,18 @@ Risk level explanations for each detection dimension: ### Deny Response Body -When content is blocked, the plugin (`MultiModalGuard` action) returns the following structured JSON object. The location in the response depends on the protocol: +When content is blocked, the plugin (`MultiModalGuard` action) builds the following structured JSON object. 
`protocol: original`, MCP, and image-generation paths return it directly or indirectly; OpenAI text-generation wrapping keeps the historical response shape by default, and embeds this object only when `openAIDenyResponseFormat: structured` is configured. ```json { + "code": 200, + "denyMessage": "Sorry, I cannot answer your question.", "blockedDetails": [ { - "Type": "contentModeration", - "Level": "high", - "Suggestion": "block" + "type": "contentModeration", + "level": "high" } - ], - "requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****", - "guardCode": 200 + ] } ``` @@ -91,22 +91,26 @@ Field descriptions: | Field | Type | Description | | --- | --- | --- | -| `blockedDetails` | array | Details of the triggered blocking dimensions. Synthesised from top-level risk signals when the security service returns no detail entries. | -| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` | -| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low` etc. | -| `blockedDetails[].Suggestion` | string | Action recommended by the security service, usually `block` | -| `requestId` | string | Request ID from the security service, for tracing | -| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` indicates a successful check that detected a risk) | +| `code` | int | For `text_generation` (OpenAI wrapping) and `image_generation` paths, this is the HTTP status the gateway returns, sourced from `denyCode` (default `200`). For `protocol=original` and `mcp` paths, this is the business code returned by the security service (`Response.Code`; `200` indicates a successful check that detected a risk). | +| `denyMessage` | string | Human-readable deny text. Always present on OpenAI-wrapping paths, taken from `denyMessage` (defaults to `Sorry, I cannot answer your question.`). 
On `protocol=original` / `image_generation` / `mcp` paths the value is taken from `denyMessage` and omitted (`omitempty`) when unconfigured. | +| `blockedDetails` | array | Details of the triggered blocking dimensions. Synthesised from top-level `RiskLevel`/`AttackLevel` when the security service returns no `Detail` entries. Returns `[]` when no dimension is hit. | +| `blockedDetails[].type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` / `customLabel` | +| `blockedDetails[].level` | string | Risk level: `high` / `medium` / `low`; for sensitive data: `S1`–`S4` | + +> Note: the current implementation emits only the fields above. The security service's `RequestId`, per-detail `Suggestion`, and raw business code (`guardCode`) are not embedded in the deny body. The security service's `RequestId` is exposed via the AI access log field `safecheck_request_ids` (see the AI Log section below). How the body is embedded per protocol: -- **`text_generation` (OpenAI non-streaming)**: serialised as a JSON string and placed in `choices[0].message.content` -- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk -- **`text_generation` (`protocol=original`)**: returned directly as the JSON response body +- **`text_generation` (OpenAI, default `legacy`)**: emits neither `x_higress_guardrail` nor the historical `x_higress` field; `choices[0].message.content` / the first `delta.content` frame keeps the historical content shape (a JSON string for RiskBlock, deny text for mask fallback), `finish_reason` is `"stop"`, and streaming responses still end with `data: [DONE]` +- **`text_generation` (OpenAI, `structured` non-streaming)**: `choices[0].message.content` carries the human-readable deny text (`denyMessage`, defaults to `Sorry, I cannot answer your question.` when unconfigured); the structure above is placed at `choices[0].x_higress_guardrail` as an embedded object (not a 
JSON string) +- **`text_generation` (OpenAI, `structured` streaming SSE)**: the first frame's `delta.content` carries the human-readable deny text; the structure above is attached only to the last chunk at `choices[0].x_higress_guardrail` as an embedded object, followed by `data: [DONE]` +- **`text_generation` (`protocol=original`)**: returned directly as the JSON response body (no OpenAI wrapper, no `x_higress_guardrail`) - **`image_generation`**: returned directly as the JSON response body (HTTP 403) - **`mcp` (JSON-RPC)**: serialised as a JSON string and placed in `error.message` - **`mcp` (SSE)**: same, returned via SSE event +`openAIDenyResponseFormat` only changes the OpenAI-wrapped deny body shape; blocking decisions, fail-open behavior, metrics, and AI Log fields do not vary by format. Configure this field only at plugin global scope, not under `consumerRiskLevel`. + ## Examples of configuration ### Check if the input is legal @@ -131,6 +135,14 @@ checkRequest: true checkResponse: true ``` +### Configure OpenAI Structured Deny Responses + +The default `openAIDenyResponseFormat: legacy` keeps the historical response shape. To emit structured blocking details in OpenAI responses, configure: + +```yaml +openAIDenyResponseFormat: structured +``` + ### Configure response fallback extraction paths When primary extraction paths are empty, you can configure ordered fallback paths to support multiple response formats: @@ -165,7 +177,57 @@ ai-security-guard plugin provides following metrics: - `ai_sec_request_deny`: count of requests denied at request phase - `ai_sec_response_deny`: count of requests denied at response phase +#### Image response-phase metric / ai_log rename (transition window) + +The image generation handlers (`lvwang/multi_modal_guard/image/openai.go` and `lvwang/multi_modal_guard/image/qwen.go`) historically emitted request-phase field names for **response-phase** events. 
This release corrects the semantics and keeps a **double-write transition** for 1–2 release cycles: + +| Signal | Legacy value (wrong; removed in a future release) | New value (recommended) | +| --- | --- | --- | +| Counter (deny) | `ai_sec_request_deny` | `ai_sec_response_deny` | +| ai_log latency (pass + deny) | `safecheck_request_rt` | `safecheck_response_rt` | +| ai_log status (deny) | `safecheck_status="reqeust deny"` (typo; **dropped immediately, no longer emitted**) | `safecheck_status="response deny"` | + +During the transition window, the image response phase emits both the new and the legacy `*_deny` counters and `safecheck_*_rt` attributes; `safecheck_status` only emits the new value. Migrate dashboards / alerts to the `response_*` names; any image-response alert that still keys off the typo'd `reqeust deny` status string must move to `response deny` immediately. + ### Trace ai-security-guard plugin provides following span attributes: - `ai_sec_risklabel`: risk type of this request -- `ai_sec_deny_phase`: denied phase of this request, value can be request/response \ No newline at end of file +- `ai_sec_deny_phase`: denied phase of this request, value can be request/response + +### AI Log +ai-security-guard writes each submission to the content security service into the AI access log, so gateway logs can be correlated with Alibaba Cloud content security requests: + +| Field | Type | Description | +| --- | --- | --- | +| `safecheck_requests` | array | Submission event array. 
Each item is `{"requestId"?: string, "phase": string, "modality": string, "result": string}` | +| `safecheck_request_ids` | array | All valid content security `RequestId` values for the current gateway request, preserved in submission completion order without deduplication or truncation | +| `safecheck_request_id` | string | The latest valid content security `RequestId`, kept for consumers that only read a single value | +| `safecheck_status` | string | Legacy compatibility field reflecting the last status transition for this gateway request (see enum below) | +| `safecheck_request_rt` / `safecheck_response_rt` | int | Latency (ms) of the security check during the request / response phase | +| `safecheck_riskLabel` / `safecheck_riskWords` | string | Risk label and risk words when a risk is hit (taken from the first result returned by the security service) | + +`safecheck_requests[].phase` is `request` or `response`; `modality` is `text`, `image`, or `mcp`; `result` describes **the processing outcome of that submission event itself** (not the gateway's final outbound action). Values: + +| `result` value | Meaning | +| --- | --- | +| `pass` | The submission passed the check | +| `deny` | The submission hit a risk; the gateway returned a deny response | +| `mask` | The submission hit a risk with `Action=Mask`; the security service returned desensitized text and the request body was rewritten | +| `error` | The submission itself failed (HTTP non-200, business `Code` non-200, unmarshal failure, deny-response build failure, dispatch failure, etc.). 
When the failure occurs in the **streaming response callback** because building the deny response failed, the gateway fails open (injects buffered upstream content as-is); in that case `safecheck_status=build_fallback_pass` and the corresponding event has `result=error` to indicate the security submission did not complete | + +The plugin writes `requestId`, `safecheck_request_ids`, and `safecheck_request_id` only when the security service response contains a JSON string `RequestId` and `strings.TrimSpace(RequestId) != ""`; missing, empty, whitespace-only, or non-string values do not produce empty placeholders. + +Every submission attempt emits one `safecheck_requests` event, including HTTP non-200 responses, business failures, and failures to dispatch the security service call. These error paths are recorded as `result=error`. Use `safecheck_requests` for precise auditing across multiple submissions, streaming chunks, or multiple image checks. + +`safecheck_status` enum (legacy field; overwritten on each status transition, so only the last transition's value is preserved when there are multiple submissions): + +| `safecheck_status` value | Meaning | +| --- | --- | +| `request pass` | All request-phase submissions passed | +| `request mask` | A request-phase submission hit mask; the request body was rewritten with desensitized text | +| `reqeust deny` | A request-phase submission hit a risk; the gateway returned a deny response (note: typo `reqeust` is preserved for backward compatibility) | +| `request error` | A request-phase security submission itself failed (HTTP / unmarshal / dispatch / etc.); the gateway fails open | +| `response pass` | All response-phase submissions passed | +| `response deny` | A response-phase submission hit a risk; the gateway returned a deny response | +| `response error` | A response-phase security submission itself failed; the gateway fails open | +| `build_fallback_pass` | In the streaming response callback, building the deny response 
failed; the gateway fails open and injects the buffered upstream content as-is | diff --git a/plugins/wasm-go/extensions/ai-security-guard/ai_log_test.go b/plugins/wasm-go/extensions/ai-security-guard/ai_log_test.go new file mode 100644 index 000000000..e69cc8bcc --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/ai_log_test.go @@ -0,0 +1,454 @@ +package main + +import ( + "encoding/json" + "testing" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/iface" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type aiLogSnapshot struct { + SafecheckRequests []cfg.GuardrailSubmissionEvent `json:"safecheck_requests"` + SafecheckRequestIDs []string `json:"safecheck_request_ids"` + SafecheckRequestID string `json:"safecheck_request_id"` + SafecheckStatus string `json:"safecheck_status"` +} + +func readAILogSnapshot(t *testing.T, host test.TestHost) (aiLogSnapshot, string) { + t.Helper() + raw, err := host.GetProperty([]string{wrapper.AILogKey}) + require.NoError(t, err) + decoded := wrapper.UnmarshalStr(`"` + string(raw) + `"`) + require.NotEmpty(t, decoded) + + var snapshot aiLogSnapshot + require.NoError(t, json.Unmarshal([]byte(decoded), &snapshot)) + return snapshot, decoded +} + +func requireAILogArraySchema(t *testing.T, raw string) { + t.Helper() + require.True(t, gjson.Get(raw, cfg.SafecheckRequestsKey).IsArray(), "safecheck_requests must be a JSON array") + require.True(t, gjson.Get(raw, cfg.SafecheckRequestIDsKey).IsArray(), "safecheck_request_ids must be a JSON array") +} + +func requireSafecheckEvent(t *testing.T, event cfg.GuardrailSubmissionEvent, phase, modality, result, requestID string) { + t.Helper() + require.Equal(t, phase, event.Phase) + require.Equal(t, modality, event.Modality) + 
require.Equal(t, result, event.Result) + require.Equal(t, requestID, event.RequestID) +} + +func TestGuardrailAILogRequestAndResponseEventSchema(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("request pass emits one structured text event", func(t *testing.T) { + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "Hello"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-structured-pass", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + snapshot, raw := readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 1) + requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultPass, "req-structured-pass") + require.Equal(t, []string{"req-structured-pass"}, snapshot.SafecheckRequestIDs) + require.Equal(t, "req-structured-pass", snapshot.SafecheckRequestID) + require.Equal(t, "request pass", snapshot.SafecheckStatus) + }) + + t.Run("response deny emits one structured text event", func(t *testing.T) { + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + body := `{"choices": [{"message": {"role": 
"assistant", "content": "bad response content"}}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-structured-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + snapshot, raw := readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 1) + requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText, cfg.GuardrailResultDeny, "req-structured-deny") + require.Equal(t, []string{"req-structured-deny"}, snapshot.SafecheckRequestIDs) + require.Equal(t, "req-structured-deny", snapshot.SafecheckRequestID) + require.Equal(t, "response deny", snapshot.SafecheckStatus) + }) + }) +} + +func TestGuardrailAILogStreamingPassFlushesBeforeEOS(t *testing.T) { + streamingFlushConfig := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": false, + "checkResponse": true, + "action": "MultiModalGuard", + "apiType": "text_generation", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + "bufferLimit": 1, + }) + return data + }() + + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(streamingFlushConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "text/event-stream"}, + }) + + chunk := 
[]byte("data:{\"id\":\"chatcmpl-1\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n") + host.CallOnHttpStreamingResponseBody(chunk, false) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-pass", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + snapshot, raw := readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 1) + requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText, cfg.GuardrailResultPass, "req-stream-pass") + require.False(t, gjson.Get(raw, "safecheck_status").Exists(), "event-level flush should not wait for a terminal safecheck_status") + }) +} + +func TestGuardrailAILogErrorFlushAndOrderingForImageSubmissions(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + runCase := func(t *testing.T, firstHeaders [][2]string, firstResponse, firstRequestID string) { + host, status := test.NewTestHost(multiModalGuardImageQwenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + }) + + body := `{"input": {"images": ["https://example.com/a.png", "https://example.com/b.png"]}}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + host.CallOnHttpCall(firstHeaders, []byte(firstResponse)) + snapshot, raw := readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 1) + requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage, cfg.GuardrailResultError, firstRequestID) + require.Equal(t, []string{firstRequestID}, snapshot.SafecheckRequestIDs) + + secondResponse := `{"Code": 200, "Message": "Success", 
"RequestId": "req-image-pass", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(secondResponse)) + + snapshot, raw = readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 2) + requireSafecheckEvent(t, snapshot.SafecheckRequests[1], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage, cfg.GuardrailResultPass, "req-image-pass") + require.Equal(t, []string{firstRequestID, "req-image-pass"}, snapshot.SafecheckRequestIDs) + require.Equal(t, "req-image-pass", snapshot.SafecheckRequestID) + } + + t.Run("non-200 HTTP response flushes error before next image submission", func(t *testing.T) { + runCase(t, [][2]string{ + {":status", "502"}, + {"content-type", "application/json"}, + }, `{"RequestId": "req-http-error"}`, "req-http-error") + }) + + t.Run("business failure flushes error before next image submission", func(t *testing.T) { + runCase(t, [][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, `{"Code": 500, "Message": "Failed", "RequestId": "req-business-error"}`, "req-business-error") + }) + }) +} + +func TestGuardrailAILogMalformedRequestIDsAreIgnored(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + cases := []struct { + name string + response string + expectedResult string + }{ + { + name: "missing", + response: `{"Code": 200, "Message": "Success", "Data": {"RiskLevel": "low"}}`, + expectedResult: cfg.GuardrailResultPass, + }, + { + name: "empty", + response: `{"Code": 200, "Message": "Success", "RequestId": "", "Data": {"RiskLevel": "low"}}`, + expectedResult: cfg.GuardrailResultPass, + }, + { + name: "whitespace", + response: `{"Code": 200, "Message": "Success", "RequestId": " ", "Data": {"RiskLevel": "low"}}`, + expectedResult: cfg.GuardrailResultPass, + }, + { + name: "non-string", + response: `{"Code": 200, "Message": "Success", "RequestId": 123, "Data": {"RiskLevel": "low"}}`, + 
expectedResult: cfg.GuardrailResultError, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "Hello"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(tc.response)) + + snapshot, raw := readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 1) + requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, tc.expectedResult, "") + require.Empty(t, snapshot.SafecheckRequestIDs) + require.False(t, gjson.Get(raw, cfg.SafecheckRequestIDKey).Exists()) + require.False(t, gjson.Get(raw, cfg.SafecheckRequestsKey+".0.requestId").Exists()) + }) + } + }) +} + +func TestGuardrailAILogMaskFallbackRecordsDeny(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "敏感内容"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-mask-fallback", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": [{"Label": "phone", "Confidence": 99.0, + "Ext": {"Desensitization": ""}}] + }] + } + }` + 
host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + snapshot, raw := readAILogSnapshot(t, host) + requireAILogArraySchema(t, raw) + require.Len(t, snapshot.SafecheckRequests, 1) + requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultDeny, "req-mask-fallback") + require.Equal(t, []string{"req-mask-fallback"}, snapshot.SafecheckRequestIDs) + }) +} + +func TestGuardrailAILogDispatchFailureEmitsErrorEvent(t *testing.T) { + ctx := newStubHTTPContext() + eventIndex := cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText) + cfg.CompleteGuardrailSubmissionEventWithRequestID(ctx, eventIndex, "", cfg.GuardrailResultError) + cfg.WriteGuardrailLog(ctx) + + events, ok := ctx.GetUserAttribute(cfg.SafecheckRequestsKey).([]cfg.GuardrailSubmissionEvent) + require.True(t, ok) + require.Len(t, events, 1) + requireSafecheckEvent(t, events[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultError, "") + + requestIDs, ok := ctx.GetUserAttribute(cfg.SafecheckRequestIDsKey).([]string) + require.True(t, ok) + require.Empty(t, requestIDs) + require.Nil(t, ctx.GetUserAttribute(cfg.SafecheckRequestIDKey)) + require.Equal(t, []string{wrapper.AILogKey}, ctx.writes) +} + +type stubHTTPContext struct { + userContext map[string]interface{} + userAttribute map[string]interface{} + bufferQueue [][]byte + writes []string + routeCallError error +} + +func newStubHTTPContext() *stubHTTPContext { + return &stubHTTPContext{ + userContext: map[string]interface{}{}, + userAttribute: map[string]interface{}{}, + } +} + +func (ctx *stubHTTPContext) Scheme() string { return "" } +func (ctx *stubHTTPContext) Host() string { return "" } +func (ctx *stubHTTPContext) Path() string { return "" } +func (ctx *stubHTTPContext) Method() string { return "" } + +func (ctx *stubHTTPContext) SetContext(key 
string, value interface{}) { + ctx.userContext[key] = value +} + +func (ctx *stubHTTPContext) GetContext(key string) interface{} { + return ctx.userContext[key] +} + +func (ctx *stubHTTPContext) GetBoolContext(key string, defaultValue bool) bool { + if value, ok := ctx.userContext[key].(bool); ok { + return value + } + return defaultValue +} + +func (ctx *stubHTTPContext) GetStringContext(key, defaultValue string) string { + if value, ok := ctx.userContext[key].(string); ok { + return value + } + return defaultValue +} + +func (ctx *stubHTTPContext) GetByteSliceContext(key string, defaultValue []byte) []byte { + if value, ok := ctx.userContext[key].([]byte); ok { + return value + } + return defaultValue +} + +func (ctx *stubHTTPContext) GetUserAttribute(key string) interface{} { + return ctx.userAttribute[key] +} + +func (ctx *stubHTTPContext) SetUserAttribute(key string, value interface{}) { + ctx.userAttribute[key] = value +} + +func (ctx *stubHTTPContext) SetUserAttributeMap(kvmap map[string]interface{}) { + ctx.userAttribute = kvmap +} + +func (ctx *stubHTTPContext) GetUserAttributeMap() map[string]interface{} { + return ctx.userAttribute +} + +func (ctx *stubHTTPContext) WriteUserAttributeToLog() error { + return ctx.WriteUserAttributeToLogWithKey(wrapper.CustomLogKey) +} + +func (ctx *stubHTTPContext) WriteUserAttributeToLogWithKey(key string) error { + ctx.writes = append(ctx.writes, key) + return nil +} + +func (ctx *stubHTTPContext) WriteUserAttributeToTrace() error { return nil } +func (ctx *stubHTTPContext) DontReadRequestBody() {} +func (ctx *stubHTTPContext) DontReadResponseBody() {} +func (ctx *stubHTTPContext) BufferRequestBody() {} +func (ctx *stubHTTPContext) BufferResponseBody() {} +func (ctx *stubHTTPContext) NeedPauseStreamingResponse() {} + +func (ctx *stubHTTPContext) PushBuffer(buffer []byte) { + ctx.bufferQueue = append(ctx.bufferQueue, buffer) +} + +func (ctx *stubHTTPContext) PopBuffer() []byte { + if len(ctx.bufferQueue) == 0 { + return 
nil + } + buffer := ctx.bufferQueue[0] + ctx.bufferQueue = ctx.bufferQueue[1:] + return buffer +} + +func (ctx *stubHTTPContext) BufferQueueSize() int { return len(ctx.bufferQueue) } +func (ctx *stubHTTPContext) DisableReroute() {} +func (ctx *stubHTTPContext) SetRequestBodyBufferLimit(uint32) { +} +func (ctx *stubHTTPContext) SetResponseBodyBufferLimit(uint32) { +} + +func (ctx *stubHTTPContext) RouteCall(string, string, [][2]string, []byte, iface.RouteResponseCallback) error { + return ctx.routeCallError +} + +func (ctx *stubHTTPContext) GetExecutionPhase() iface.HTTPExecutionPhase { + return iface.DecodeData +} + +func (ctx *stubHTTPContext) HasRequestBody() bool { return true } +func (ctx *stubHTTPContext) HasResponseBody() bool { return true } +func (ctx *stubHTTPContext) IsWebsocket() bool { return false } +func (ctx *stubHTTPContext) IsBinaryRequestBody() bool { return false } +func (ctx *stubHTTPContext) IsBinaryResponseBody() bool { + return false +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/ai_log.go b/plugins/wasm-go/extensions/ai-security-guard/config/ai_log.go new file mode 100644 index 000000000..32219c6de --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/config/ai_log.go @@ -0,0 +1,126 @@ +package config + +import ( + "strings" + "time" + + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + SafecheckRequestsKey = "safecheck_requests" + SafecheckRequestIDsKey = "safecheck_request_ids" + SafecheckRequestIDKey = "safecheck_request_id" + + GuardrailPhaseRequest = "request" + GuardrailPhaseResponse = "response" + + GuardrailModalityText = "text" + GuardrailModalityImage = "image" + GuardrailModalityMCP = "mcp" + + GuardrailResultPass = "pass" + GuardrailResultDeny = "deny" + GuardrailResultMask = "mask" + GuardrailResultError = "error" +) + +type GuardrailSubmissionEvent struct { + RequestID string `json:"requestId,omitempty"` + Phase string `json:"phase"` + Modality string 
`json:"modality"` + Result string `json:"result"` +} + +// BeginGuardrailSubmissionEvent appends a placeholder event so append order matches +// the current serial submission order. The event is flushed only after completion. +func BeginGuardrailSubmissionEvent(ctx wrapper.HttpContext, phase, modality string) int { + events := getGuardrailSubmissionEvents(ctx) + events = append(events, GuardrailSubmissionEvent{ + Phase: phase, + Modality: modality, + }) + setGuardrailSubmissionEvents(ctx, events) + return len(events) - 1 +} + +func CompleteGuardrailSubmissionEvent(ctx wrapper.HttpContext, index int, responseBody []byte, result string) { + CompleteGuardrailSubmissionEventWithRequestID(ctx, index, ExtractValidRequestID(responseBody), result) +} + +func CompleteGuardrailSubmissionEventWithRequestID(ctx wrapper.HttpContext, index int, requestID, result string) { + events := getGuardrailSubmissionEvents(ctx) + if index < 0 || index >= len(events) { + return + } + events[index].Result = result + if requestID != "" { + events[index].RequestID = requestID + } + setGuardrailSubmissionEvents(ctx, events) +} + +// WriteGuardrailLog writes current guardrail-related user attributes to the AI log. +// Call after submission events are updated; Complete* does not flush the log. +func WriteGuardrailLog(ctx wrapper.HttpContext) { + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) +} + +// MarkGuardrailRequestError finalizes a request-phase safecheck submission that failed +// upstream (HTTP, unmarshal, downstream build, or dispatch error). It overwrites the +// legacy safecheck_status with "request error" so consumers that only watch the single +// status field do not see a stale "request pass" left over from a prior chunk, and +// records safecheck_request_rt so latency metrics cover failures too. 
+func MarkGuardrailRequestError(ctx wrapper.HttpContext, index int, responseBody []byte, startTime int64) {
+	finishGuardrailErrorEvent(ctx, index, responseBody, startTime, "safecheck_request_rt", "request error")
+}
+
+// MarkGuardrailResponseError finalizes a failed response-phase safecheck
+// submission: it records safecheck_response_rt, sets safecheck_status to
+// "response error", marks the event at index as an error and flushes the AI log.
+func MarkGuardrailResponseError(ctx wrapper.HttpContext, index int, responseBody []byte, startTime int64) {
+	finishGuardrailErrorEvent(ctx, index, responseBody, startTime, "safecheck_response_rt", "response error")
+}
+
+// finishGuardrailErrorEvent carries the bookkeeping shared by both phases.
+// rtKey names the latency attribute (milliseconds elapsed since startTime) and
+// status is the value written to safecheck_status.
+func finishGuardrailErrorEvent(ctx wrapper.HttpContext, index int, responseBody []byte, startTime int64, rtKey, status string) {
+	elapsedMs := time.Now().UnixMilli() - startTime
+	ctx.SetUserAttribute(rtKey, elapsedMs)
+	ctx.SetUserAttribute("safecheck_status", status)
+	// Complete the event before flushing so the written log already carries it.
+	CompleteGuardrailSubmissionEvent(ctx, index, responseBody, GuardrailResultError)
+	WriteGuardrailLog(ctx)
+}
+
+// ExtractValidRequestID returns the whitespace-trimmed top-level "RequestId" of
+// responseBody. It returns "" when the body is empty, the field is missing, the
+// field is not a JSON string, or the value is blank after trimming.
+func ExtractValidRequestID(responseBody []byte) string {
+	if len(responseBody) == 0 {
+		return ""
+	}
+	field := gjson.GetBytes(responseBody, "RequestId")
+	if field.Exists() && field.Type == gjson.String {
+		// A whitespace-only ID trims down to "", which is the "invalid" result.
+		return strings.TrimSpace(field.String())
+	}
+	return ""
+}
+
+// getGuardrailSubmissionEvents reads the event list stored under
+// SafecheckRequestsKey, falling back to an empty slice when the attribute is
+// absent, nil, or holds a different type.
+func getGuardrailSubmissionEvents(ctx wrapper.HttpContext) []GuardrailSubmissionEvent {
+	if stored, ok := ctx.GetUserAttribute(SafecheckRequestsKey).([]GuardrailSubmissionEvent); ok && stored != nil {
+		return stored
+	}
+	return []GuardrailSubmissionEvent{}
+}
+
+// setGuardrailSubmissionEvents stores events and refreshes the derived
+// attributes: the list of non-empty request IDs in event order, and (only when
+// at least one exists) the last of those IDs under SafecheckRequestIDKey.
+func setGuardrailSubmissionEvents(ctx wrapper.HttpContext, events []GuardrailSubmissionEvent) {
+	ctx.SetUserAttribute(SafecheckRequestsKey, events)
+
+	ids := make([]string, 0, len(events))
+	for i := range events {
+		if id := events[i].RequestID; id != "" {
+			ids = append(ids, id)
+		}
+	}
+	ctx.SetUserAttribute(SafecheckRequestIDsKey, ids)
+	if n := len(ids); n > 0 {
+		ctx.SetUserAttribute(SafecheckRequestIDKey, ids[n-1])
+	}
+}
diff --git
a/plugins/wasm-go/extensions/ai-security-guard/config/config.go b/plugins/wasm-go/extensions/ai-security-guard/config/config.go index e3c9fb252..9f5ac09e1 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/config/config.go +++ b/plugins/wasm-go/extensions/ai-security-guard/config/config.go @@ -6,12 +6,16 @@ import ( "fmt" "regexp" "strings" + "time" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) +type OpenAIDenyResponseFormat string + const ( MaxRisk = "max" HighRisk = "high" @@ -35,10 +39,31 @@ const ( WaterMarkType = "waterMark" // Default configurations - OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` - OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` + // Template parameter order: + // OpenAIResponseFormatLegacy: id, created (unix sec), content + // OpenAIResponseFormatStructured: id, created (unix sec), content, x_higress_guardrail JSON + // OpenAIStreamResponseChunk: id, created, content + // OpenAIStreamResponseEndLegacy: id, created + // OpenAIStreamResponseEndStructured: id, created, x_higress_guardrail JSON + // 
OpenAIStreamResponseFormatLegacy: id, created, content, id, created + // OpenAIStreamResponseFormatStructured: id, created, content, id, created, x_higress_guardrail JSON + // `created` is required by openai-python (ChatCompletion.created is non-Optional). + // `finish_reason: "stop"` preserves wire-level compatibility with downstream + // consumers (LangChain / LiteLLM / SDKs / BI dashboards) that treat `stop` as + // "valid completion"; the moderation-event signal lives in the nested + // `choices[0].x_higress_guardrail` block (denyCode / blockedDetails) instead. + OpenAIResponseFormatLegacy = `{"id":"%s","object":"chat.completion","created":%d,"model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + OpenAIResponseFormatStructured = `{"id":"%s","object":"chat.completion","created":%d,"model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop","x_higress_guardrail":%s}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` + OpenAIStreamResponseEndLegacy = `data:{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + OpenAIStreamResponseEndStructured = `data:{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop","x_higress_guardrail":%s}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + 
OpenAIStreamResponseFormatLegacy = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEndLegacy + "\n\n" + `data: [DONE]` + OpenAIStreamResponseFormatStructured = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEndStructured + "\n\n" + `data: [DONE]` + + OpenAIDenyResponseFormatLegacy OpenAIDenyResponseFormat = "legacy" + OpenAIDenyResponseFormatStructured OpenAIDenyResponseFormat = "structured" + + OpenAIDenyResponseFormatConsumerScopeError = "openAIDenyResponseFormat must be configured at plugin global scope, not under consumerRiskLevel" DefaultDenyCode = 200 DefaultDenyMessage = "很抱歉,我无法回答您的问题" @@ -184,6 +209,7 @@ type AISecurityConfig struct { DenyCode int64 DenyMessage string ProtocolOriginal bool + OpenAIDenyResponseFormat OpenAIDenyResponseFormat RiskLevelBar string ContentModerationLevelBar string PromptAttackLevelBar string @@ -296,6 +322,16 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { config.CheckRequestImage = json.Get("checkRequestImage").Bool() config.CheckResponse = json.Get("checkResponse").Bool() config.ProtocolOriginal = json.Get("protocol").String() == "original" + if obj := json.Get("openAIDenyResponseFormat"); obj.Exists() { + switch OpenAIDenyResponseFormat(obj.String()) { + case OpenAIDenyResponseFormatLegacy: + config.OpenAIDenyResponseFormat = OpenAIDenyResponseFormatLegacy + case OpenAIDenyResponseFormatStructured: + config.OpenAIDenyResponseFormat = OpenAIDenyResponseFormatStructured + default: + return errors.New("invalid openAIDenyResponseFormat, value must be one of [legacy, structured]") + } + } config.DenyMessage = json.Get("denyMessage").String() if obj := json.Get("denyCode"); obj.Exists() { config.DenyCode = obj.Int() @@ -411,6 +447,9 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { for k, v := range item.Map() { m[k] = v.Value() } + if _, ok := m["openAIDenyResponseFormat"]; ok { + return errors.New(OpenAIDenyResponseFormatConsumerScopeError) + } consumerName, ok1 := 
m["name"] matchType, ok2 := m["matchType"] if !ok1 || !ok2 { @@ -531,6 +570,7 @@ func (config *AISecurityConfig) SetDefaultValues() { config.ApiType = ApiTextGeneration config.ProviderType = ProviderOpenAI config.RiskAction = "block" + config.OpenAIDenyResponseFormat = OpenAIDenyResponseFormatLegacy } func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) { @@ -951,6 +991,151 @@ func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer return json.Marshal(body) } +// ResolveDenyMessage returns the human-readable deny text used both for +// non-original OpenAI wrappers (message.content / delta.content) and for the +// x_higress_guardrail.denyMessage field, ensuring the two stay aligned. +func ResolveDenyMessage(config AISecurityConfig) string { + if config.DenyMessage != "" { + return config.DenyMessage + } + return DefaultDenyMessage +} + +// BuildOpenAIDenyResponseBody builds the guardrail JSON object embedded by the +// outer OpenAI structured template as choices[0].x_higress_guardrail. Its shape +// mirrors DenyResponseBody, but DenyMessage is filled via ResolveDenyMessage so +// the field is always present and consistent with the rendered content. +// Code is sourced from config.DenyCode so that x_higress_guardrail.code +// consistently represents "the HTTP status this gateway returns to the client", +// aligned with BuildOpenAIFallbackDenyResponseBody. 
+func BuildOpenAIDenyResponseBody(response Response, config AISecurityConfig, consumer string) ([]byte, error) {
+	hits := GetUnacceptableDetail(response.Data, config, consumer)
+	// Allocate at full length so the slice is non-nil even when nothing matched.
+	blockedDetails := make([]BlockedDetail, len(hits))
+	for i := range hits {
+		// Only the type and level of each hit are exposed to the client.
+		blockedDetails[i] = BlockedDetail{Type: hits[i].Type, Level: hits[i].Level}
+	}
+	return json.Marshal(DenyResponseBody{
+		Code:           int(config.DenyCode),
+		DenyMessage:    ResolveDenyMessage(config),
+		BlockedDetails: blockedDetails,
+	})
+}
+
+// BuildOpenAIFallbackDenyResponseBody builds the guardrail JSON object that the
+// outer OpenAI structured template embeds as choices[0].x_higress_guardrail on
+// mask→block fallback paths. The fallback is triggered by
+// ReplaceJsonFieldTextContent failure or empty desensitization, so no upstream
+// Response object exists to derive blockedDetails from; the list is therefore
+// emitted empty.
+func BuildOpenAIFallbackDenyResponseBody(config AISecurityConfig) ([]byte, error) {
+	return json.Marshal(DenyResponseBody{
+		Code:           int(config.DenyCode),
+		DenyMessage:    ResolveDenyMessage(config),
+		BlockedDetails: []BlockedDetail{},
+	})
+}
+
+// openAIDenyContentType returns the content-type header list for an
+// OpenAI-wrapped deny response: event-stream for streams, JSON otherwise.
+func openAIDenyContentType(isStream bool) [][2]string {
+	contentType := "application/json"
+	if isStream {
+		contentType = "text/event-stream;charset=UTF-8"
+	}
+	return [][2]string{{"content-type", contentType}}
+}
+
+// BuildOpenAIDenyData builds the complete OpenAI-formatted deny response bytes
+// (structured or legacy). Callers that need raw bytes (e.g. streaming response
+// handlers using InjectEncodedDataToFilterChain) use this directly; callers that
+// want a full SendHttpResponse dispatch should use SendDenyResponse instead.
+func BuildOpenAIDenyData(config AISecurityConfig, response Response, consumer string, isStream bool) ([]byte, error) {
+	structured := config.OpenAIDenyResponseFormat == OpenAIDenyResponseFormatStructured
+
+	// Step 1: pick what goes into message.content / delta.content.
+	//   structured: the resolved deny message, plus a separate guardrail object.
+	//   legacy:     the marshalled deny body itself is the content.
+	var content, guardrail string
+	if structured {
+		raw, err := BuildOpenAIDenyResponseBody(response, config, consumer)
+		if err != nil {
+			return nil, err
+		}
+		guardrail = string(raw)
+		content = wrapper.MarshalStr(ResolveDenyMessage(config))
+	} else {
+		raw, err := BuildDenyResponseBody(response, config, consumer)
+		if err != nil {
+			return nil, err
+		}
+		content = wrapper.MarshalStr(string(raw))
+	}
+
+	// Step 2: render the matching template. Stream templates take the id and
+	// created timestamp twice, once for the content chunk and once for the end chunk.
+	chatID := utils.GenerateRandomChatID()
+	created := time.Now().Unix()
+
+	var rendered string
+	switch {
+	case structured && isStream:
+		rendered = fmt.Sprintf(OpenAIStreamResponseFormatStructured, chatID, created, content, chatID, created, guardrail)
+	case structured:
+		rendered = fmt.Sprintf(OpenAIResponseFormatStructured, chatID, created, content, guardrail)
+	case isStream:
+		rendered = fmt.Sprintf(OpenAIStreamResponseFormatLegacy, chatID, created, content, chatID, created)
+	default:
+		rendered = fmt.Sprintf(OpenAIResponseFormatLegacy, chatID, created, content)
+	}
+	return []byte(rendered), nil
+}
+
+// SendDenyResponse dispatches a deny HTTP response in the appropriate format
+// (ProtocolOriginal, Structured, or Legacy). It returns an error only if
+// building the response body fails; the caller should handle the error
+// (e.g. log, mark guardrail error, resume request/response).
+func SendDenyResponse(config AISecurityConfig, response Response, consumer string, isStream bool) error {
+	var (
+		payload []byte
+		headers [][2]string
+		err     error
+	)
+	if config.ProtocolOriginal {
+		// protocol=original: the bare deny body, always served as JSON.
+		headers = [][2]string{{"content-type", "application/json"}}
+		payload, err = BuildDenyResponseBody(response, config, consumer)
+	} else {
+		// OpenAI wrapper: legacy or structured, stream or not.
+		headers = openAIDenyContentType(isStream)
+		payload, err = BuildOpenAIDenyData(config, response, consumer, isStream)
+	}
+	if err != nil {
+		// Nothing has been sent yet; the caller decides how to recover.
+		return err
+	}
+	proxywasm.SendHttpResponse(uint32(config.DenyCode), headers, payload, -1)
+	return nil
+}
+
+// SendFallbackDenyResponse dispatches a fallback deny HTTP response when no
+// upstream Response object is available (e.g. mask-to-block on replace error
+// or empty desensitization). For ProtocolOriginal it sends the plain deny
+// message; for OpenAI formats it wraps it in the appropriate template with
+// an empty-blockedDetails guardrail object (structured) or the deny message
+// as content (legacy).
+func SendFallbackDenyResponse(config AISecurityConfig, isStream bool) error {
+	denyText := wrapper.MarshalStr(ResolveDenyMessage(config))
+	statusCode := uint32(config.DenyCode)
+
+	if config.ProtocolOriginal {
+		// protocol=original: the marshalled deny message is the whole body.
+		proxywasm.SendHttpResponse(statusCode, [][2]string{{"content-type", "application/json"}}, []byte(denyText), -1)
+		return nil
+	}
+
+	chatID := utils.GenerateRandomChatID()
+	created := time.Now().Unix()
+
+	var rendered string
+	if config.OpenAIDenyResponseFormat == OpenAIDenyResponseFormatStructured {
+		// Structured output also carries a guardrail object; a build failure
+		// is returned before anything is sent.
+		guardrail, err := BuildOpenAIFallbackDenyResponseBody(config)
+		if err != nil {
+			return err
+		}
+		if isStream {
+			rendered = fmt.Sprintf(OpenAIStreamResponseFormatStructured, chatID, created, denyText, chatID, created, string(guardrail))
+		} else {
+			rendered = fmt.Sprintf(OpenAIResponseFormatStructured, chatID, created, denyText, string(guardrail))
+		}
+	} else if isStream {
+		rendered = fmt.Sprintf(OpenAIStreamResponseFormatLegacy, chatID, created, denyText, chatID, created)
+	} else {
+		rendered = fmt.Sprintf(OpenAIResponseFormatLegacy, chatID, created, denyText)
+	}
+
+	proxywasm.SendHttpResponse(statusCode, openAIDenyContentType(isStream), []byte(rendered), -1)
+	return nil
+}
+
 func GetUnacceptableDetail(data Data, config AISecurityConfig, consumer string) []Detail {
 	result := []Detail{}
 	for _, detail := range data.Detail {
diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go
index e736d55ac..55bd064d6 100644
--- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go
+++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go
@@ -3,7 +3,6 @@ package text
 import (
 	"bytes"
"encoding/json" - "fmt" "net/http" "strings" "time" @@ -21,6 +20,7 @@ import ( const ( responseFallbackPathsCtxKey = "response_fallback_paths" responseStreamFallbackPathsCtxKey = "response_stream_fallback_paths" + responseStartTimeCtxKey = "response_start_time" ) func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { @@ -28,6 +28,7 @@ func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISe ctx.SetContext("end_of_stream_received", false) ctx.SetContext("during_call", false) ctx.SetContext("risk_detected", false) + ctx.SetContext(responseStartTimeCtxKey, time.Now().UnixMilli()) ctx.SetContext(responseFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths)) ctx.SetContext(responseStreamFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)) sessionID, _ := utils.GenerateHexID(20) @@ -52,10 +53,13 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c sessionID, _ = ctx.GetContext("sessionID").(string) } var bufferQueue [][]byte + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) if ctx.GetContext("end_of_stream_received").(bool) { proxywasm.ResumeHttpResponse() } @@ -66,6 +70,8 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c err := json.Unmarshal(responseBody, &response) if err != nil { log.Error("failed to unmarshal aliyun content security response at response phase") + startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64) + 
cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) if ctx.GetContext("end_of_stream_received").(bool) { proxywasm.ResumeHttpResponse() } @@ -73,24 +79,50 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c return } if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { - denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) + jsonData, err := cfg.BuildOpenAIDenyData(config, response, consumer, true) if err != nil { + // Build failure → fail-open: inject the buffered upstream content as-is. + // Make this path observable so operators can spot the silent passthrough + // instead of mistakenly attributing observed denies-only to the success + // path's metrics. Symmetric with the success path's observability suite + // (counter / safecheck_response_rt / safecheck_status / log / risk_detected). + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError) log.Errorf("failed to build deny response body: %v", err) endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0 proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream) bufferQueue = [][]byte{} + config.IncrementCounter("ai_sec_response_deny_buildfail", 1) + startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64) + ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime) + ctx.SetUserAttribute("safecheck_status", "build_fallback_pass") + if len(response.Data.Result) > 0 { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) if !endStream { ctx.SetContext("during_call", false) singleCall() } return } - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - randomID := utils.GenerateRandomChatID() - 
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) proxywasm.InjectEncodedDataToFilterChain(jsonData, true) + ctx.SetContext("risk_detected", true) + ctx.SetContext("during_call", false) + config.IncrementCounter("ai_sec_response_deny", 1) + startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64) + ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime) + ctx.SetUserAttribute("safecheck_status", "response deny") + if len(response.Data.Result) > 0 { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) return } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + cfg.WriteGuardrailLog(ctx) endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0 proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream) bufferQueue = [][]byte{} @@ -126,10 +158,13 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c ctx.SetContext("during_call", true) log.Debugf("current content piece: %s", buffer) checkService := config.GetResponseCheckService(consumer) + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText) path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID) err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime) if 
ctx.GetContext("end_of_stream_received").(bool) { proxywasm.ResumeHttpResponse() } @@ -198,10 +233,12 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu } contentIndex := 0 sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } @@ -209,6 +246,7 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu err := json.Unmarshal(responseBody, &response) if err != nil { log.Error("failed to unmarshal aliyun content security response at response phase") + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } @@ -217,41 +255,32 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "response pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(content) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpResponse() } else { singleCall() } return } - denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) - if err != nil { + if err := cfg.SendDenyResponse(config, response, consumer, isStreamingResponse); err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } - if config.ProtocolOriginal { - 
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1) - } else if isStreamingResponse { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } config.IncrementCounter("ai_sec_response_deny", 1) endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "response deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCall = func() { var nextContentIndex int @@ -264,10 +293,12 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) checkService := config.GetResponseCheckService(consumer) + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText) path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID) err := 
config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpResponse() } } diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go index 61c84e10b..37d8d931b 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go @@ -54,11 +54,14 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. contentIndex := 0 imageIndex := 0 sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 + currentImageSubmissionIndex := 0 var singleCall func() var singleCallForImage func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -66,6 +69,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -74,10 +78,13 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. 
endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(content) { if len(images) > 0 && config.CheckRequestImage { singleCallForImage() } else { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } } else { @@ -88,6 +95,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -97,13 +105,15 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText) var nextContentIndex int if contentIndex+cfg.LengthLimit >= len(content) { nextContentIndex = len(content) @@ -117,6 +127,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. 
err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } @@ -125,6 +136,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. imageIndex += 1 log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) if imageIndex < len(images) { singleCallForImage() } else { @@ -136,6 +148,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) if imageIndex < len(images) { singleCallForImage() } else { @@ -148,7 +161,10 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. if imageIndex >= len(images) { ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if imageIndex >= len(images) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } else { singleCallForImage() @@ -159,6 +175,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -167,13 +184,15 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. 
config.IncrementCounter("ai_sec_request_deny", 1) ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCallForImage = func() { + currentImageSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage) img := images[imageIndex] imgUrl := "" imgBase64 := "" @@ -186,6 +205,7 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg. err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } @@ -207,11 +227,13 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg return types.ActionContinue } imageIndex := 0 + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { imageIndex += 1 log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) if imageIndex < len(imgResults) { singleCall() } else { @@ -223,6 +245,7 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, 
startTime) if imageIndex < len(imgResults) { singleCall() } else { @@ -233,9 +256,16 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg endTime := time.Now().UnixMilli() if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { if imageIndex >= len(imgResults) { + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + // Transitional double-write: this handler historically wrote safecheck_request_rt + // for response-phase emissions; existing dashboards key off the old name. + // Remove after 1–2 release cycles. ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + ctx.SetUserAttribute("safecheck_status", "response pass") + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if imageIndex >= len(imgResults) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpResponse() } else { singleCall() @@ -245,16 +275,25 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1) + config.IncrementCounter("ai_sec_response_deny", 1) + // Transitional double-write: this handler historically incremented ai_sec_request_deny + // and wrote safecheck_request_rt for response-phase denies; existing dashboards/alerts + // key off the old names. The legacy safecheck_status="reqeust deny" (typo) is dropped + // in this transition. Remove after 1–2 release cycles. 
config.IncrementCounter("ai_sec_request_deny", 1) + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "reqeust deny") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + ctx.SetUserAttribute("safecheck_status", "response deny") + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityImage) img := imgResults[imageIndex] imgUrl := "" imgBase64 := "" @@ -267,6 +306,7 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpResponse() } } diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go index daefca679..df6b3904a 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go @@ -212,11 +212,14 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI contentIndex := 0 imageIndex := 0 sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 + currentImageSubmissionIndex := 0 var singleCall func() var singleCallForImage func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, 
currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -224,6 +227,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -232,10 +236,13 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(content) { if len(images) > 0 && config.CheckRequestImage { singleCallForImage() } else { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } } else { @@ -246,6 +253,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -255,13 +263,15 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, 
responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText) var nextContentIndex int if contentIndex+cfg.LengthLimit >= len(content) { nextContentIndex = len(content) @@ -275,6 +285,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } @@ -283,6 +294,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI imageIndex += 1 log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) if imageIndex < len(images) { singleCallForImage() } else { @@ -294,6 +306,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) if imageIndex < len(images) { singleCallForImage() } else { @@ -306,7 +319,10 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI if imageIndex >= len(images) { ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if imageIndex >= len(images) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } else { singleCallForImage() @@ -317,6 +333,7 @@ func 
HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -325,13 +342,15 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI config.IncrementCounter("ai_sec_request_deny", 1) ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCallForImage = func() { + currentImageSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage) img := images[imageIndex] imgUrl := "" imgBase64 := "" @@ -344,6 +363,7 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } @@ -365,11 +385,13 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A return types.ActionContinue } imageIndex := 0 + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { imageIndex += 1 log.Info(string(responseBody)) if statusCode != 200 || 
gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) if imageIndex < len(imgUrls) { singleCall() } else { @@ -381,6 +403,7 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) if imageIndex < len(imgUrls) { singleCall() } else { @@ -391,9 +414,16 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A endTime := time.Now().UnixMilli() if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { if imageIndex >= len(imgUrls) { + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + // Transitional double-write: this handler historically wrote safecheck_request_rt + // for response-phase emissions; existing dashboards key off the old name. + // Remove after 1–2 release cycles. 
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + ctx.SetUserAttribute("safecheck_status", "response pass") + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if imageIndex >= len(imgUrls) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpResponse() } else { singleCall() @@ -403,21 +433,31 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1) + config.IncrementCounter("ai_sec_response_deny", 1) + // Transitional double-write: this handler historically incremented ai_sec_request_deny + // and wrote safecheck_request_rt for response-phase denies; existing dashboards/alerts + // key off the old names. The legacy safecheck_status="reqeust deny" (typo) is dropped + // in this transition. Remove after 1–2 release cycles. 
config.IncrementCounter("ai_sec_request_deny", 1) + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "reqeust deny") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + ctx.SetUserAttribute("safecheck_status", "response deny") + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityImage) imgUrl := imgUrls[imageIndex] path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "") err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpResponse() } } diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/mcp/mcp.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/mcp/mcp.go index e88041e1b..b2607a40e 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/mcp/mcp.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/mcp/mcp.go @@ -43,10 +43,12 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, } contentIndex := 0 sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -54,6 +56,7 @@ func 
HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -62,7 +65,10 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(content) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } else { singleCall() @@ -74,22 +80,25 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) denyResponse := fmt.Sprintf(DenyResponse, marshalledDenyMessage) proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, 
[]byte(denyResponse), -1) } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityMCP) var nextContentIndex int if contentIndex+cfg.LengthLimit >= len(content) { nextContentIndex = len(content) @@ -103,6 +112,7 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } @@ -114,6 +124,7 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte { consumer, _ := ctx.GetContext("consumer").(string) var frontBuffer []byte + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { defer func() { @@ -122,6 +133,9 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri }() log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + ctx.SetUserAttribute("safecheck_status", "response error") + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError) + cfg.WriteGuardrailLog(ctx) proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false) return } @@ -129,6 +143,9 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri err := json.Unmarshal(responseBody, &response) if err != nil { log.Error("failed to unmarshal aliyun content security response at response phase") + ctx.SetUserAttribute("safecheck_status", "response error") + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError) + 
cfg.WriteGuardrailLog(ctx) proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false) return } @@ -136,13 +153,20 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + ctx.SetUserAttribute("safecheck_status", "response error") + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError) + cfg.WriteGuardrailLog(ctx) proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false) return } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) denySSEResponse := fmt.Sprintf(DenySSEResponse, marshalledDenyMessage) proxywasm.InjectEncodedDataToFilterChain([]byte(denySSEResponse), true) } else { + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + cfg.WriteGuardrailLog(ctx) proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false) } } @@ -158,10 +182,14 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri ctx.SetContext("during_call", true) checkService := config.GetResponseCheckService(consumer) sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityMCP) path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, msg, sessionID) err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + ctx.SetUserAttribute("safecheck_status", "response error") + cfg.CompleteGuardrailSubmissionEventWithRequestID(ctx, currentSubmissionIndex, "", cfg.GuardrailResultError) + cfg.WriteGuardrailLog(ctx) 
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false) ctx.SetContext("during_call", false) } @@ -194,10 +222,12 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, } contentIndex := 0 sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } @@ -205,6 +235,7 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, err := json.Unmarshal(responseBody, &response) if err != nil { log.Error("failed to unmarshal aliyun content security response at response phase") + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } @@ -213,7 +244,10 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "response pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(content) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpResponse() } else { singleCall() @@ -224,17 +258,19 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "response deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) 
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) if err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpResponse() return } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) denyResponseBody := fmt.Sprintf(DenyResponse, marshalledDenyMessage) proxywasm.RemoveHttpResponseHeader("content-length") @@ -253,10 +289,12 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) checkService := config.GetResponseCheckService(consumer) + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityMCP) path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID) err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpResponse() } } diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go index 25e41be19..70695c637 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go @@ -2,7 +2,6 @@ package text import ( "encoding/json" - "fmt" "net/http" "strings" "time" @@ -67,6 +66,8 @@ func 
HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur hasMasked := false maskedContent := []byte(content) sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 + currentImageSubmissionIndex := 0 var singleCall func() var singleCallForImage func() // prevContentIndex tracks the start of the current chunk for masking replacement @@ -74,6 +75,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -81,6 +83,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -97,26 +100,17 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur if replaceErr != nil { log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr) // Fall back to block to prevent leaking sensitive data - denyMessage := cfg.DefaultDenyMessage - if config.DenyMessage != "" { - denyMessage = config.DenyMessage - } - marshalledDenyMessage := wrapper.MarshalStr(denyMessage) - if config.ProtocolOriginal { - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := utils.GenerateRandomChatID() - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", 
"text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := utils.GenerateRandomChatID() - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + if sendErr := cfg.SendFallbackDenyResponse(config, gjson.GetBytes(body, "stream").Bool()); sendErr != nil { + log.Errorf("failed to build deny response body: %v", sendErr) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) + proxywasm.ResumeHttpRequest() + return } ctx.DontReadResponseBody() config.IncrementCounter("ai_sec_request_deny", 1) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) return } proxywasm.ReplaceHttpRequestBody(newBody) @@ -125,10 +119,13 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur } else { ctx.SetUserAttribute("safecheck_status", "request pass") } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(maskedContent) { if len(images) > 0 && config.CheckRequestImage { singleCallForImage() } else { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } } else { @@ -140,6 +137,30 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur if desensitization == "" { proxywasm.LogInfof("safecheck_action_source=mask_fallback_to_block, reason=empty_desensitization") log.Warnf("desensitization content is empty, falling back to block logic") + // Keep this fallback separate from RiskBlock: legacy reuses the + // original deny body in content, while structured emits an empty + // fallback guardrail object. 
+ isStream := gjson.GetBytes(body, "stream").Bool() + var sendErr error + if !config.ProtocolOriginal && config.OpenAIDenyResponseFormat != cfg.OpenAIDenyResponseFormatStructured { + sendErr = cfg.SendDenyResponse(config, response, consumer, isStream) + } else { + sendErr = cfg.SendFallbackDenyResponse(config, isStream) + } + if sendErr != nil { + log.Errorf("failed to build deny response body: %v", sendErr) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) + proxywasm.ResumeHttpRequest() + return + } + ctx.DontReadResponseBody() + config.IncrementCounter("ai_sec_request_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) + return } else { // Replace only the current chunk portion in maskedContent chunkStart := prevContentIndex @@ -156,28 +177,19 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur if replaceErr != nil { log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr) // Fall back to block to prevent leaking sensitive data - denyMessage := cfg.DefaultDenyMessage - if config.DenyMessage != "" { - denyMessage = config.DenyMessage - } - marshalledDenyMessage := wrapper.MarshalStr(denyMessage) - if config.ProtocolOriginal { - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := utils.GenerateRandomChatID() - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - 
randomID := utils.GenerateRandomChatID() - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + if sendErr := cfg.SendFallbackDenyResponse(config, gjson.GetBytes(body, "stream").Bool()); sendErr != nil { + log.Errorf("failed to build deny response body: %v", sendErr) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) + proxywasm.ResumeHttpRequest() + return } ctx.DontReadResponseBody() config.IncrementCounter("ai_sec_request_deny", 1) endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) return } proxywasm.ReplaceHttpRequestBody(newBody) @@ -185,52 +197,41 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request mask") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultMask) if len(images) > 0 && config.CheckRequestImage { singleCallForImage() } else { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } } else { + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultMask) singleCall() } return } - // Fall through to block logic when desensitization is empty - fallthrough case cfg.RiskBlock: - denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) - if err != nil { + if err := cfg.SendDenyResponse(config, response, consumer, gjson.GetBytes(body, 
"stream").Bool()); err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } - if config.ProtocolOriginal { - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } ctx.DontReadResponseBody() config.IncrementCounter("ai_sec_request_deny", 1) endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText) prevContentIndex = contentIndex var nextContentIndex int if contentIndex+cfg.LengthLimit >= 
len(maskedContent) { @@ -245,6 +246,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } @@ -253,6 +255,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur imageIndex += 1 log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) if imageIndex < len(images) { singleCallForImage() } else { @@ -264,6 +267,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur err := json.Unmarshal(responseBody, &response) if err != nil { log.Errorf("%+v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) if imageIndex < len(images) { singleCallForImage() } else { @@ -276,7 +280,10 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur if imageIndex >= len(images) { ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if imageIndex >= len(images) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } else { singleCallForImage() @@ -284,36 +291,25 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur return } - denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) - if err != nil { + if err := cfg.SendDenyResponse(config, response, consumer, gjson.GetBytes(body, "stream").Bool()); err != nil { log.Errorf("failed to build deny response 
body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } - if config.ProtocolOriginal { - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } ctx.DontReadResponseBody() config.IncrementCounter("ai_sec_request_deny", 1) ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + cfg.CompleteGuardrailSubmissionEvent(ctx, currentImageSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCallForImage = func() { + currentImageSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage) img := images[imageIndex] imgUrl := "" imgBase64 := "" @@ -326,6 +322,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur err := config.Client.Post(path, 
headers, body, callbackForImage, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentImageSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go index 31c82fac8..e8b9a01cc 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go @@ -2,7 +2,6 @@ package text import ( "encoding/json" - "fmt" "net/http" "time" @@ -27,10 +26,12 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur } contentIndex := 0 sessionID, _ := utils.GenerateHexID(20) + currentSubmissionIndex := 0 var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -38,6 +39,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur err := json.Unmarshal(responseBody, &response) if err != nil { log.Error("failed to unmarshal aliyun content security response at request phase") + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } @@ -46,44 +48,36 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + cfg.CompleteGuardrailSubmissionEvent(ctx, 
currentSubmissionIndex, responseBody, cfg.GuardrailResultPass) + if contentIndex >= len(content) { + cfg.WriteGuardrailLog(ctx) proxywasm.ResumeHttpRequest() } else { singleCall() } return } - denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) - if err != nil { + if err := cfg.SendDenyResponse(config, response, consumer, gjson.GetBytes(body, "stream").Bool()); err != nil { log.Errorf("failed to build deny response body: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, responseBody, startTime) proxywasm.ResumeHttpRequest() return } - if config.ProtocolOriginal { - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } ctx.DontReadResponseBody() config.IncrementCounter("ai_sec_request_deny", 1) endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { + if len(response.Data.Result) > 0 { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + 
cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny) + cfg.WriteGuardrailLog(ctx) } singleCall = func() { + currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText) var nextContentIndex int if contentIndex+cfg.LengthLimit >= len(content) { nextContentIndex = len(content) @@ -97,6 +91,7 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur err := config.Client.Post(path, headers, body, callback, config.Timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) + cfg.MarkGuardrailRequestError(ctx, currentSubmissionIndex, nil, startTime) proxywasm.ResumeHttpRequest() } } diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 841fdb3b4..9afe34485 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -312,6 +312,29 @@ var protocolOriginalConfig = func() json.RawMessage { return data }() +func withConfigOverrides(base json.RawMessage, overrides map[string]interface{}) json.RawMessage { + var config map[string]interface{} + _ = json.Unmarshal(base, &config) + for k, v := range overrides { + config[k] = v + } + data, _ := json.Marshal(config) + return data +} + +func withStructuredFormat(base json.RawMessage) json.RawMessage { + return withConfigOverrides(base, map[string]interface{}{ + "openAIDenyResponseFormat": string(cfg.OpenAIDenyResponseFormatStructured), + }) +} + +func mustDecodeLegacyDenyContent(t *testing.T, content string) cfg.DenyResponseBody { + t.Helper() + var denyBody cfg.DenyResponseBody + require.NoError(t, json.Unmarshal([]byte(content), &denyBody)) + return denyBody +} + func TestParseConfig(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试基础配置解析 @@ -335,6 +358,69 @@ func TestParseConfig(t *testing.T) { 
require.Equal(t, 1000, securityConfig.BufferLimit) require.Equal(t, cfg.DefaultResponseFallbackJsonPaths(), securityConfig.ResponseContentFallbackJsonPaths) require.Equal(t, cfg.DefaultStreamingResponseFallbackJsonPaths(), securityConfig.ResponseStreamContentFallbackJsonPaths) + require.Equal(t, cfg.OpenAIDenyResponseFormatLegacy, securityConfig.OpenAIDenyResponseFormat) + }) + + t.Run("openai deny response format explicit legacy", func(t *testing.T) { + host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{ + "openAIDenyResponseFormat": "legacy", + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, cfg.OpenAIDenyResponseFormatLegacy, securityConfig.OpenAIDenyResponseFormat) + }) + + t.Run("openai deny response format explicit structured", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(requestOnlyConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, cfg.OpenAIDenyResponseFormatStructured, securityConfig.OpenAIDenyResponseFormat) + }) + + t.Run("invalid openai deny response format", func(t *testing.T) { + host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{ + "openAIDenyResponseFormat": "json", + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + t.Run("empty openai deny response format is invalid", func(t *testing.T) { + host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{ + "openAIDenyResponseFormat": "", + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + t.Run("consumer risk level cannot override 
openai deny response format", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "action": "MultiModalGuard", + "contentModerationLevelBar": "high", + "consumerRiskLevel": []map[string]interface{}{ + { + "name": "consumer-a", + "matchType": "exact", + "openAIDenyResponseFormat": "structured", + }, + }, + }) + require.NoError(t, err) + var securityConfig cfg.AISecurityConfig + parseErr := securityConfig.Parse(gjson.ParseBytes(configJSON)) + require.EqualError(t, parseErr, cfg.OpenAIDenyResponseFormatConsumerScopeError) }) // 测试仅检查请求的配置 @@ -509,7 +595,7 @@ func TestOnHttpRequestHeaders(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试启用请求检查的情况 t.Run("request checking enabled", func(t *testing.T) { - host, status := test.NewTestHost(basicConfig) + host, status := test.NewTestHost(withStructuredFormat(basicConfig)) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) @@ -547,7 +633,7 @@ func TestOnHttpRequestBody(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试请求体安全检查通过 t.Run("request body security check pass", func(t *testing.T) { - host, status := test.NewTestHost(basicConfig) + host, status := test.NewTestHost(withStructuredFormat(basicConfig)) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) @@ -2655,7 +2741,7 @@ func TestTextModerationPlusResponseDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // TextModerationPlus response deny → exercises text_moderation_plus/text (via common/text) BuildDenyResponseBody response path t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) { - host, status := test.NewTestHost(basicConfig) + host, status := test.NewTestHost(withStructuredFormat(basicConfig)) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, 
status) @@ -2684,22 +2770,26 @@ func TestTextModerationPlusResponseDeny(t *testing.T) { require.NotNil(t, local, "expected SendHttpResponse for response deny") require.Contains(t, string(local.Data), "blockedDetails") - // Verify OpenAI completion shape wrapper + // Verify OpenAI completion shape wrapper in structured mode: + // message.content carries only the human-readable deny text and the + // structured deny payload moves to choices[0].x_higress_guardrail. type openAIChatCompletion struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` + Guardrail cfg.DenyResponseBody `json:"x_higress_guardrail"` } `json:"choices"` } var outer openAIChatCompletion require.NoError(t, json.Unmarshal(local.Data, &outer)) require.Len(t, outer.Choices, 1) - var deny cfg.DenyResponseBody - require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny)) - require.Equal(t, 200, deny.Code) - require.NotEmpty(t, deny.BlockedDetails) + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content) + require.Equal(t, 200, outer.Choices[0].Guardrail.Code) + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage) + require.NotEmpty(t, outer.Choices[0].Guardrail.BlockedDetails) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) }) }) } @@ -3289,3 +3379,835 @@ func TestMultiModalGuardMaskStreamDeny(t *testing.T) { }) }) } + +func TestOpenAIDenyLegacyDefaultNonStream(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("default legacy non-stream response keeps deny body in content", func(t *testing.T) { + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + 
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-default", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + require.Equal(t, "stop", gjson.GetBytes(local.Data, "choices.0.finish_reason").String()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) + + content := gjson.GetBytes(local.Data, "choices.0.message.content").String() + denyBody := mustDecodeLegacyDenyContent(t, content) + require.Equal(t, 200, denyBody.Code) + require.NotEmpty(t, denyBody.BlockedDetails) + }) + }) +} + +func TestOpenAIDenyLegacyDefaultStream(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("default legacy stream response keeps deny body in first content frame", func(t *testing.T) { + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-stream", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + raw := string(local.Data) + require.True(t, strings.HasSuffix(strings.TrimSpace(raw), "data: [DONE]")) + parts := 
strings.Split(raw, "\n\n") + require.GreaterOrEqual(t, len(parts), 3) + + firstFrame := strings.TrimSpace(strings.TrimPrefix(parts[0], "data:")) + endFrame := strings.TrimSpace(strings.TrimPrefix(parts[1], "data:")) + + firstContent := gjson.Get(firstFrame, "choices.0.delta.content").String() + denyBody := mustDecodeLegacyDenyContent(t, firstContent) + require.Equal(t, 200, denyBody.Code) + require.NotEmpty(t, denyBody.BlockedDetails) + + require.False(t, gjson.Get(firstFrame, "choices.0.x_higress_guardrail").Exists()) + require.False(t, gjson.Get(firstFrame, "choices.0.x_higress").Exists()) + require.False(t, gjson.Get(endFrame, "choices.0.x_higress_guardrail").Exists()) + require.False(t, gjson.Get(endFrame, "choices.0.x_higress").Exists()) + require.Equal(t, "stop", gjson.Get(endFrame, "choices.0.finish_reason").String()) + }) + }) +} + +func TestOpenAIDenyLegacyDenyCodeKeepsResponseCodeInContent(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("legacy content code remains safecheck response code when denyCode differs", func(t *testing.T) { + host, status := test.NewTestHost(withConfigOverrides(multiModalGuardTextConfig, map[string]interface{}{ + "denyCode": 451, + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-451", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + require.Equal(t, uint32(451), local.StatusCode) + + content := gjson.GetBytes(local.Data, 
"choices.0.message.content").String() + denyBody := mustDecodeLegacyDenyContent(t, content) + require.Equal(t, 200, denyBody.Code) + }) + }) +} + +func TestMaskEmptyDesensitizationOpenAIFormats(t *testing.T) { + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-mask-empty-openai", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": [{"Label": "phone", "Confidence": 99.0, + "Ext": {"Desensitization": ""}}] + }] + } + }` + + test.RunTest(t, func(t *testing.T) { + t.Run("legacy empty desensitization uses json-stringified deny body", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "敏感内容"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) + + content := gjson.GetBytes(local.Data, "choices.0.message.content").String() + denyBody := mustDecodeLegacyDenyContent(t, content) + require.Equal(t, 200, denyBody.Code) + require.Empty(t, denyBody.BlockedDetails) + }) + + t.Run("structured empty desensitization uses fallback guardrail", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(maskConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", 
"/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "敏感内容"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String()) + require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) + require.Equal(t, int64(0), gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.blockedDetails.#").Int()) + }) + }) +} + +func TestMaskReplaceJsonFieldFailureOpenAIFormats(t *testing.T) { + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-mask-replace-failure", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": [{"Label": "phone", "Confidence": 99.0, + "Ext": {"Desensitization": "masked"}}] + }] + } + }` + + test.RunTest(t, func(t *testing.T) { + t.Run("legacy replace failure keeps pure deny message content", func(t *testing.T) { + host, status := test.NewTestHost(withConfigOverrides(maskConfig, map[string]interface{}{ + "requestContentJsonPath": "@this.messages.0.content", + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "敏感内容"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse)) + + local := host.GetLocalResponse() + 
require.NotNil(t, local) + require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) + }) + + t.Run("structured replace failure emits fallback guardrail", func(t *testing.T) { + host, status := test.NewTestHost(withConfigOverrides(withStructuredFormat(maskConfig), map[string]interface{}{ + "requestContentJsonPath": "@this.messages.0.content", + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "敏感内容"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String()) + require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject()) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) + require.Equal(t, int64(0), gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.blockedDetails.#").Int()) + }) + }) +} + +// ============================================================================= +// x_higress_guardrail 扩展字段链路与模板测试 +// ============================================================================= + +// openAIChoiceWithGuardrail is the minimal OpenAI choice shape used by +// x_higress_guardrail assertions. 
Guardrail is unmarshaled directly into the strongly +// typed cfg.DenyResponseBody so tests assert the documented contract — code +// is int, denyMessage is string, blockedDetails is a slice — instead of +// silently tolerating shape drift through map[string]interface{}. +type openAIChoiceWithGuardrail struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + Delta struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"delta"` + FinishReason *string `json:"finish_reason"` + Guardrail cfg.DenyResponseBody `json:"x_higress_guardrail"` +} + +// openAIBodyWithGuardrail also carries top-level extension fields so tests can +// assert the design contract that x_higress_guardrail lives ONLY inside choices[0] — +// a regression where it leaks to the body root would deserialize into this +// field and the require.Empty check at the call site would fail. +type openAIBodyWithGuardrail struct { + Choices []openAIChoiceWithGuardrail `json:"choices"` + Guardrail *cfg.DenyResponseBody `json:"x_higress_guardrail,omitempty"` + XHigress *cfg.DenyResponseBody `json:"x_higress,omitempty"` +} + +// TestRequestDenyGuardrailNonStream verifies that 请求阶段非流式 deny renders +// content as plain text and embeds x_higress_guardrail as a JSON object. 
+func TestRequestDenyGuardrailNonStream(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("non-stream request deny carries guardrail object", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-x-higress", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + // x_higress_guardrail must be a JSON object, not a string + require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(), + "x_higress_guardrail should be an embedded JSON object, not a string literal") + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(), + "structured deny must not emit the old choices[0].x_higress field") + + var outer openAIBodyWithGuardrail + require.NoError(t, json.Unmarshal(local.Data, &outer)) + require.Len(t, outer.Choices, 1) + + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content, + "content should carry only the human-readable deny text") + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage) + // A-R1F4 contract: x_higress_guardrail.code carries the gateway-emitted HTTP deny + // status (config.DenyCode, default 200), NOT the upstream security + // service's Response.Code. multiModalGuardTextConfig leaves denyCode + // unset, so this resolves to cfg.DefaultDenyCode. 
+ require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code, + "x_higress_guardrail.code must equal the gateway-emitted HTTP deny status") + require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails) + + // Design contract: x_higress_guardrail lives ONLY nested under choices[0]. + require.Nil(t, outer.Guardrail, + "x_higress_guardrail must not leak to body root; only choices[0].x_higress_guardrail is valid") + require.Nil(t, outer.XHigress, "old x_higress must not leak to body root") + require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root; only choices[0].x_higress_guardrail is valid") + require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(), + "old x_higress must not leak to body root") + }) + }) +} + +// TestRequestDenyGuardrailStreamFrames verifies that 请求阶段流式 deny only +// embeds x_higress_guardrail in the final chunk and the first chunk carries plain text. +func TestRequestDenyGuardrailStreamFrames(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("stream request deny only attaches guardrail in last chunk", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-x-higress", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + raw := string(local.Data) + 
require.True(t, strings.HasSuffix(strings.TrimSpace(raw), "data: [DONE]")) + parts := strings.Split(raw, "\n\n") + // 模板格式: data:{first chunk}\n\ndata:{end chunk}\n\ndata: [DONE] + require.GreaterOrEqual(t, len(parts), 3, "expected at least chunk + end + DONE") + + firstFrame := strings.TrimPrefix(parts[0], "data:") + endFrame := strings.TrimPrefix(parts[1], "data:") + + require.False(t, gjson.Get(firstFrame, "choices.0.x_higress_guardrail").Exists(), + "first chunk should not carry x_higress_guardrail") + require.False(t, gjson.Get(firstFrame, "choices.0.x_higress").Exists(), + "first chunk should not carry old x_higress") + require.Equal(t, cfg.DefaultDenyMessage, + gjson.Get(firstFrame, "choices.0.delta.content").String(), + "first chunk delta.content should be plain text") + + require.True(t, gjson.Get(endFrame, "choices.0.x_higress_guardrail").IsObject(), + "final chunk should carry x_higress_guardrail as object") + require.False(t, gjson.Get(endFrame, "choices.0.x_higress").Exists(), + "final chunk should not carry old x_higress") + // Deny stream's terminator carries `stop` for wire-level compatibility + // with downstream consumers (LangChain / LiteLLM / SDKs / BI) that key + // off `stop` as a valid completion. The moderation-event signal lives + // in choices[0].x_higress_guardrail (denyCode / blockedDetails) instead. + require.Equal(t, "stop", gjson.Get(endFrame, "choices.0.finish_reason").String()) + require.False(t, gjson.Get(endFrame, "choices.0.delta.content").Exists(), + "final chunk delta should be empty") + + // Design contract: x_higress_guardrail lives ONLY nested under choices[0]. 
+ require.False(t, gjson.Get(endFrame, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root of the end frame") + require.False(t, gjson.Get(firstFrame, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root of the first frame") + require.False(t, gjson.Get(endFrame, "x_higress").Exists(), + "old x_higress must not leak to body root of the end frame") + require.False(t, gjson.Get(firstFrame, "x_higress").Exists(), + "old x_higress must not leak to body root of the first frame") + + require.Contains(t, parts[2], "[DONE]") + }) + }) +} + +// TestRequestDenyDefaultDenyMessage verifies that without a configured +// denyMessage, both content and x_higress_guardrail.denyMessage fall back to the default +// text. +func TestRequestDenyDefaultDenyMessage(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("default deny message used when not configured", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-default", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + content := gjson.GetBytes(local.Data, "choices.0.message.content").String() + denyMessage := gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.denyMessage").String() + require.Equal(t, cfg.DefaultDenyMessage, content) + require.Equal(t, cfg.DefaultDenyMessage, 
denyMessage) + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists()) + }) + }) +} + +// TestProtocolOriginalDenyShapePreserved guards the regression that the +// protocol: "original" normal deny path keeps the bare DenyResponseBody shape +// without OpenAI wrapping or x_higress_guardrail. +func TestProtocolOriginalDenyShapePreserved(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + for _, tc := range []struct { + name string + config json.RawMessage + }{ + {name: "default format", config: protocolOriginalConfig}, + {name: "legacy format", config: withConfigOverrides(protocolOriginalConfig, map[string]interface{}{"openAIDenyResponseFormat": "legacy"})}, + {name: "structured format", config: withStructuredFormat(protocolOriginalConfig)}, + } { + tc := tc + t.Run("original protocol deny stays as bare DenyResponseBody "+tc.name, func(t *testing.T) { + host, status := test.NewTestHost(tc.config) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-orig-shape", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + require.False(t, gjson.GetBytes(local.Data, "choices").Exists(), + "original protocol should not OpenAI-wrap the body") + require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(), + "original protocol should not introduce x_higress_guardrail") + require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(), + "original protocol should not 
introduce old x_higress") + require.True(t, gjson.GetBytes(local.Data, "blockedDetails").Exists()) + require.Equal(t, int64(200), gjson.GetBytes(local.Data, "code").Int()) + }) + } + }) +} + +// TestMaskEmptyDesensitizationOriginalShape guards the A-R1F2 alignment: +// when riskAction=mask and the upstream returns empty Desensitization under +// protocol: "original", the response body must be the output of +// wrapper.MarshalStr(ResolveDenyMessage(config)) — i.e. the marshaled deny +// message with the outer quotes stripped — mirroring the ReplaceJsonFieldTextContent failure path +// at lvwang/multi_modal_guard/text/openai.go:102 / :159. +// +// Before A-R1F2 this branch fell through to RiskBlock and called +// BuildDenyResponseBody, returning a structured {code, blockedDetails, ...} +// object instead. The fallthrough was inconsistent with design Section 5 and +// has been replaced with the self-handled MarshalStr path; this regression +// test locks in the new contract so the divergence cannot reappear silently. 
+func TestMaskEmptyDesensitizationOriginalShape(t *testing.T) { + maskOriginalConfig := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": false, + "action": "MultiModalGuard", + "riskAction": "mask", + "protocol": "original", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data + }() + + test.RunTest(t, func(t *testing.T) { + t.Run("mask empty desensitization under original emits MarshalStr literal", func(t *testing.T) { + host, status := test.NewTestHost(maskOriginalConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "敏感内容"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + // mask 但脱敏内容空 → 走 A-R1F2 自处理分支 + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-mask-empty-orig", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": [{"Label": "phone", "Confidence": 99.0, + "Ext": {"Desensitization": ""}}] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local) + + // 不是 OpenAI 包装 + require.False(t, gjson.GetBytes(local.Data, "choices").Exists(), + "mask empty-desensitization under original must not be OpenAI-wrapped") + require.False(t, gjson.GetBytes(local.Data, "object").Exists(), + "mask empty-desensitization under original must not include OpenAI 
'object' field") + + // 不再是 DenyResponseBody {code, blockedDetails} 结构; + // 现在是 wrapper.MarshalStr 产物 —— 裸 JSON 字符串字面量 + require.False(t, gjson.GetBytes(local.Data, "code").Exists(), + "A-R1F2: body should now be a JSON string literal, not a {code, blockedDetails} object") + require.False(t, gjson.GetBytes(local.Data, "blockedDetails").Exists(), + "A-R1F2: body should now be a JSON string literal, not a {code, blockedDetails} object") + + // 实际形态:wrapper.MarshalStr 产物。该 wrapper 在 Higress 的实现里返回 + // 已剥除外层双引号的字符串(见 C-R1F9 备注),所以 body 是原始 deny 文本 + // 字节(不可 json.Unmarshal 回字符串)。 + require.Equal(t, cfg.DefaultDenyMessage, string(local.Data), + "body should equal raw deny message (wrapper.MarshalStr strips outer quotes — C-R1F9)") + }) + }) +} + +// TestResponseStreamingDenyGuardrail drives HandleTextGenerationStreamingResponseBody +// (lvwang/common/text/openai.go) — the only response-side structured stream +// writer — and asserts the injected SSE carries: +// - first chunk: human-readable content, no x_higress_guardrail +// - end chunk: x_higress_guardrail as a JSON object with code/denyMessage/blockedDetails +// - terminator: "data: [DONE]" +func TestResponseStreamingDenyGuardrail(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("response streaming deny injects guardrail only in last frame", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(basicConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + // Skip request-phase check by using a non-deny request body. 
+ body := `{"messages": [{"role": "user", "content": "hello"}]}` + host.CallOnHttpRequestBody([]byte(body)) + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(`{"Code": 200, "Message": "Success", "RequestId": "req-stream-resp-pass", "Data": {"RiskLevel": "none"}}`)) + + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "text/event-stream"}, + }) + require.Equal(t, types.ActionContinue, action) + + // Single chunk + end_of_stream=true triggers the security check. + chunk := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"bad response\"}}]}\n\n") + host.CallOnHttpStreamingResponseBody(chunk, true) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-resp-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + injected := host.GetResponseBody() + require.NotEmpty(t, injected, "expected InjectEncodedDataToFilterChain to deliver SSE deny payload") + + injectedStr := string(injected) + require.True(t, strings.HasSuffix(strings.TrimSpace(injectedStr), "data: [DONE]"), + "expected SSE stream to end with [DONE], got: %s", injectedStr) + + // Strip "data:" prefixes and split into events. + events := strings.Split(strings.TrimSpace(injectedStr), "\n\n") + require.GreaterOrEqual(t, len(events), 2, "expected at least first chunk + end chunk") + + // First event: content present, no x_higress_guardrail. 
+ firstPayload := strings.TrimPrefix(events[0], "data:") + firstPayload = strings.TrimSpace(firstPayload) + require.Equal(t, cfg.DefaultDenyMessage, gjson.Get(firstPayload, "choices.0.delta.content").String()) + require.False(t, gjson.Get(firstPayload, "choices.0.x_higress_guardrail").Exists(), + "first chunk must NOT carry x_higress_guardrail") + require.False(t, gjson.Get(firstPayload, "choices.0.x_higress").Exists(), + "first chunk must NOT carry old x_higress") + + // Second event: x_higress_guardrail as JSON object on choices[0]. + secondPayload := strings.TrimPrefix(events[1], "data:") + secondPayload = strings.TrimSpace(secondPayload) + guardrail := gjson.Get(secondPayload, "choices.0.x_higress_guardrail") + require.True(t, guardrail.Exists(), "end chunk must carry x_higress_guardrail") + require.True(t, guardrail.IsObject(), "x_higress_guardrail must be a JSON object, not a string") + require.Equal(t, cfg.DefaultDenyMessage, guardrail.Get("denyMessage").String()) + require.True(t, guardrail.Get("blockedDetails").Exists()) + require.False(t, gjson.Get(secondPayload, "choices.0.x_higress").Exists(), + "end chunk must not carry old x_higress") + // Streaming deny terminator carries `stop` for wire-level compatibility; + // moderation signal lives in choices[0].x_higress_guardrail. + require.Equal(t, "stop", gjson.Get(secondPayload, "choices.0.finish_reason").String()) + + // Design contract: x_higress_guardrail lives ONLY nested under choices[0]. 
+ require.False(t, gjson.Get(secondPayload, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root of the end chunk") + require.False(t, gjson.Get(firstPayload, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root of the first chunk") + require.False(t, gjson.Get(secondPayload, "x_higress").Exists(), + "old x_higress must not leak to body root of the end chunk") + require.False(t, gjson.Get(firstPayload, "x_higress").Exists(), + "old x_higress must not leak to body root of the first chunk") + }) + }) +} + +// A-R2-16(a): multi_modal_guard 图像审核 deny 通道未被前文 guardrail 测试覆盖, +// 而 R1-F6 修复(BuildOpenAIFallbackDenyResponseBody err 不再静默)依赖 +// callbackForImage 与 singleCallForImage 调用 BuildOpenAIDenyResponseBody。 +// 本测试发送纯 image_url 请求体直接命中 singleCallForImage 路径,断言图像 +// 审核 deny 同样把 x_higress_guardrail 嵌入 choices[0],与文本 deny 形态对称。 +// +// 文件:lvwang/multi_modal_guard/text/openai.go:299-369(callbackForImage / OpenAI 包装) +func TestImageDenyGuardrailShape(t *testing.T) { + imageCheckConfig := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkRequestImage": true, + "action": "MultiModalGuard", + "apiType": "text_generation", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + "bufferLimit": 1000, + }) + return data + }() + + test.RunTest(t, func(t *testing.T) { + t.Run("multi_modal_guard image deny embeds guardrail on choices[0]", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(imageCheckConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + 
}) + + // 纯 image_url 内容:content=="",parseContent 跳过文本审核直接走 + // singleCallForImage,让 callbackForImage 渲染 deny。 + body := `{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"https://example.com/bad.jpg"}}]}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for image deny") + + // x_higress_guardrail 必须作为对象嵌在 choices[0] 内 + require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(), + "image deny should embed x_higress_guardrail as JSON object on choices[0]") + require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(), + "image deny should not emit old x_higress on choices[0]") + + var outer openAIBodyWithGuardrail + require.NoError(t, json.Unmarshal(local.Data, &outer)) + require.Len(t, outer.Choices, 1) + + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content, + "content carries human-readable deny text") + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage) + require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code, + "x_higress_guardrail.code = gateway HTTP deny status (A-R1F4 contract)") + require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails) + + // 不可泄漏到 body 根 + require.Nil(t, outer.Guardrail, "x_higress_guardrail must not leak to body root") + require.Nil(t, outer.XHigress, "old x_higress must not leak to body root") + require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root") + require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(), + "old x_higress must not leak 
to body root") + }) + }) +} + +// A-R2-16(b): text_moderation_plus 请求阶段 deny 通道未被前文 guardrail 测试覆盖。 +// 已有 TestTextModerationPlusResponseDeny 覆盖响应阶段,本测试补齐请求阶段对称 +// 用例,确保 OpenAI 协议下 structured deny body 把 x_higress_guardrail 放在 choices[0]。 +// +// 文件:lvwang/text_moderation_plus/text/openai.go:56-92(请求阶段 deny 渲染) +func TestTextModerationPlusRequestDenyGuardrailShape(t *testing.T) { + tmpRequestConfig := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "action": "TextModerationPlus", + "contentModerationLevelBar": "high", + "timeout": 2000, + }) + return data + }() + + test.RunTest(t, func(t *testing.T) { + t.Run("text_moderation_plus request deny embeds guardrail on choices[0]", func(t *testing.T) { + host, status := test.NewTestHost(withStructuredFormat(tmpRequestConfig)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for text_moderation_plus request deny") + + require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(), + "text_moderation_plus request deny should embed x_higress_guardrail as JSON object on choices[0]") + require.False(t, gjson.GetBytes(local.Data, 
"choices.0.x_higress").Exists(), + "text_moderation_plus request deny should not emit old x_higress") + + var outer openAIBodyWithGuardrail + require.NoError(t, json.Unmarshal(local.Data, &outer)) + require.Len(t, outer.Choices, 1) + + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content) + require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage) + require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code, + "x_higress_guardrail.code = gateway HTTP deny status (A-R1F4 contract)") + require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails) + + require.Nil(t, outer.Guardrail, "x_higress_guardrail must not leak to body root") + require.Nil(t, outer.XHigress, "old x_higress must not leak to body root") + require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(), + "x_higress_guardrail must not leak to body root") + require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(), + "old x_higress must not leak to body root") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go b/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go index f8edd0d7d..3542c618f 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go +++ b/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go @@ -58,7 +58,7 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string fieldValue := gjson.GetBytes(body, resolved) if !fieldValue.IsArray() { // Simple string content — replace directly - return sjson.SetBytes(body, resolved, newContent) + return setJsonTextContent(body, resolved, newContent) } // Array content (multimodal): replace text items, preserve others result := body @@ -86,7 +86,7 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string // If there's only one text item, put all desensitized content there if len(textEntries) == 1 { itemPath := fmt.Sprintf("%s.%d.text", resolved, 
textEntries[0].index) - return sjson.SetBytes(result, itemPath, newContent) + return setJsonTextContent(result, itemPath, newContent) } // Multiple text items: split desensitized content proportionally by original lengths for j, entry := range textEntries { @@ -117,7 +117,7 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string remaining = remaining[byteOffset:] } itemPath := fmt.Sprintf("%s.%d.text", resolved, entry.index) - result, err = sjson.SetBytes(result, itemPath, replacement) + result, err = setJsonTextContent(result, itemPath, replacement) if err != nil { return nil, err } @@ -125,6 +125,18 @@ func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string return result, nil } +func setJsonTextContent(body []byte, jsonPath string, newContent string) ([]byte, error) { + current := gjson.GetBytes(body, jsonPath) + result, err := sjson.SetBytes(body, jsonPath, newContent) + if err != nil { + return nil, err + } + if current.Exists() && current.String() != newContent && bytes.Equal(result, body) { + return nil, fmt.Errorf("failed to replace json path %q", jsonPath) + } + return result, nil +} + // resolveJsonPath converts gjson modifier paths (e.g. "messages.@reverse.0.content") // into concrete index paths (e.g. "messages.2.content") that sjson can handle. 
func resolveJsonPath(body []byte, jsonPath string) string { diff --git a/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go b/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go index 6454b8b32..e63b67f03 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go @@ -263,6 +263,14 @@ func TestResolveJsonPathEdgeCases(t *testing.T) { }) } +func TestReplaceJsonFieldTextContentReportsReadableNoOp(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"敏感内容"}]}`) + result, err := ReplaceJsonFieldTextContent(body, "@this.messages.0.content", "masked") + if err == nil { + t.Fatalf("expected error for readable path that sjson leaves unchanged, got nil with %s", string(result)) + } +} + // TestReplaceJsonFieldContent covers the simple ReplaceJsonFieldContent function func TestReplaceJsonFieldContent(t *testing.T) { body := []byte(`{"messages":[{"role":"user","content":"original"}]}`)