From b1187cc14d08faa2ce66eb6a9934bb205e6c7cdf Mon Sep 17 00:00:00 2001 From: JianweiWang Date: Wed, 15 Apr 2026 11:14:56 +0800 Subject: [PATCH] feat(ai-security-guard): enhance risk action resolution and support sensitive data masking (#3690) Co-authored-by: rinfx --- .../extensions/ai-security-guard/README.md | 12 + .../extensions/ai-security-guard/README_EN.md | 26 + .../config/action_resolver_test.go | 365 +++ .../ai-security-guard/config/config.go | 508 ++++- .../config/evaluate_risk_property_test.go | 648 ++++++ .../config/evaluate_risk_test.go | 1109 +++++++++ .../extensions/ai-security-guard/go.mod | 2 +- .../lvwang/multi_modal_guard/text/openai.go | 175 +- .../extensions/ai-security-guard/main_test.go | 1996 ++++++++++++++++- .../ai-security-guard/utils/utils.go | 109 + .../ai-security-guard/utils/utils_test.go | 277 +++ plugins/wasm-go/mcp-servers/amap-tools/go.sum | 3 + .../wasm-go/mcp-servers/quark-search/go.sum | 3 + 13 files changed, 5019 insertions(+), 214 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-security-guard/config/action_resolver_test.go create mode 100644 plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_property_test.go create mode 100644 plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_test.go create mode 100644 plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index cfc13d0c5..885fbdc4d 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -35,6 +35,8 @@ description: 阿里云内容安全检测 | `contentModerationLevelBar` | string | optional | max | 内容合规检测拦截风险等级,取值为 `max`, `high`, `medium` or `low` | | `promptAttackLevelBar` | string | optional | max | 提示词攻击检测拦截风险等级,取值为 `max`, `high`, `medium` or `low` | | `sensitiveDataLevelBar` | string | optional | S4 | 敏感内容检测拦截风险等级,取值为 `S4`, `S3`, `S2` or 
`S1` | +| `customLabelLevelBar` | string | optional | max | 自定义检测拦截风险等级,取值为 max, high, medium, low | +| `riskAction` | string | optional | block | 风险处置动作,取值为 `block` 或 `mask`。`block` 表示按风险等级阈值拦截请求,`mask` 表示当 API 返回脱敏建议时使用脱敏内容替换敏感字段。注意:脱敏功能仅适用于 MultiModalGuard 模式 | | `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 | | `bufferLimit` | int | optional | 1000 | 调用内容安全服务时每段文本的长度限制 | | `consumerRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 | @@ -93,6 +95,16 @@ description: 阿里云内容安全检测 - `S2`: 敏感内容检测结果中风险等级 >= `S2` 时产生拦截 - `S1`: 敏感内容检测结果中风险等级 >= `S1` 时产生拦截 +- 对于自定义检测(customLabel): + - `max`: 检测请求/响应内容,但是不会产生拦截行为 + - `high`: 自定义检测结果中风险等级为 `high` 时产生拦截 + - 注意:阿里云 API 对 customLabel 维度仅返回 `high` 和 `none` 两个等级,不同于其他维度的四级划分。配置为 `high` 即可在检测命中时拦截,配置为 `max` 则不拦截。`medium` 和 `low` 为配置兼容性保留,但 API 不会返回这些等级。 + +- 对于风险处置动作(riskAction): + - `block`: 按各维度的风险等级阈值判断是否拦截 + - `mask`: 当 API 返回 `Suggestion=mask` 时使用脱敏内容替换敏感字段,当 `Suggestion=block` 时仍会拦截 + - 注意:脱敏功能仅适用于 MultiModalGuard 模式(action 配置为 MultiModalGuard),其他模式不支持脱敏 + ## 配置示例 ### 前提条件 由于插件中需要调用阿里云内容安全服务,所以需要先创建一个DNS类型的服务,例如: diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md index 9afdbf934..61849763c 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -35,12 +35,38 @@ Plugin Priority: `300` | `contentModerationLevelBar` | string | optional | max | contentModeration risk level threshold, `max`, `high`, `medium` or `low` | | `promptAttackLevelBar` | string | optional | max | promptAttack risk level threshold, `max`, `high`, `medium` or `low` | | `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` | +| `customLabelLevelBar` | string | optional | max | Custom label detection risk level threshold, value can be max, high, medium, or low | +| `riskAction` | string | optional | block | Risk action, value 
can be `block` or `mask`. `block` means blocking requests based on risk level thresholds, `mask` means replacing sensitive fields with desensitized content when API returns mask suggestion. Note: masking only works with MultiModalGuard mode | | `timeout` | int | optional | 2000 | timeout for lvwang service | | `bufferLimit` | int | optional | 1000 | Limit the length of each text when calling the lvwang service | | `consumerRequestCheckService` | map | optional | - | Specify specific request detection services for different consumers | | `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers | | `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions | +Risk level explanations for each detection dimension: + +- For content moderation and prompt attack detection (contentModeration, promptAttack): + - `max`: Detect request/response content but do not block + - `high`: Block when risk level is `high` + - `medium`: Block when risk level >= `medium` + - `low`: Block when risk level >= `low` + +- For sensitive data detection (sensitiveData): + - `S4`: Detect request/response content but do not block + - `S3`: Block when risk level is `S3` + - `S2`: Block when risk level >= `S2` + - `S1`: Block when risk level >= `S1` + +- For custom label detection (customLabel): + - `max`: Detect request/response content but do not block + - `high`: Block when custom label detection result risk level is `high` + - Note: The Alibaba Cloud API only returns `high` and `none` for the customLabel dimension, unlike other dimensions which have four levels. Set to `high` to block on detection hit, set to `max` to not block. `medium` and `low` are kept for configuration compatibility but will not be returned by the API. 
+ +- For risk action (riskAction): + - `block`: Block requests based on risk level thresholds for each dimension + - `mask`: Replace sensitive fields with desensitized content when API returns `Suggestion=mask`, still block when `Suggestion=block` + - Note: Masking only works with MultiModalGuard mode (action configured as MultiModalGuard), other modes do not support masking + ### Deny Response Body When content is blocked, the plugin (`MultiModalGuard` action) returns the following structured JSON object. The location in the response depends on the protocol: diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/action_resolver_test.go b/plugins/wasm-go/extensions/ai-security-guard/config/action_resolver_test.go new file mode 100644 index 000000000..e197bfa36 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/config/action_resolver_test.go @@ -0,0 +1,365 @@ +// Copyright (c) 2024 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/proxytest" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/stretchr/testify/require" +) + +// testVMContext is a minimal VMContext for setting up the proxy-wasm mock host. +type testVMContext struct { + types.DefaultVMContext +} + +// TestMain sets up the proxy-wasm mock host for all tests in the config package. 
+// This is required because functions like enforceMaskBoundary call proxywasm.LogWarnf. +func TestMain(m *testing.M) { + opt := proxytest.NewEmulatorOption().WithVMContext(&testVMContext{}) + _, reset := proxytest.NewHostEmulator(opt) + defer reset() + os.Exit(m.Run()) +} + +// ============================================================================= +// TC-RESOLVE: 动作解析优先级测试(ResolveRiskActionByType) +// ============================================================================= + +// TestTC_RESOLVE_001 仅全局 riskAction=mask,无维度动作 +// => sensitiveData 返回 mask,非 sensitiveData 维度降级为 block +func TestTC_RESOLVE_001(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + } + + // sensitiveData 返回 mask(source=global_global) + action, source := config.ResolveRiskActionByType("", SensitiveDataType) + require.Equal(t, "mask", action) + require.Equal(t, "global_global", source) + + // promptAttack 降级为 block(source=global_global) + action, source = config.ResolveRiskActionByType("", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "global_global", source) + + // contentModeration 降级为 block + action, source = config.ResolveRiskActionByType("", ContentModerationType) + require.Equal(t, "block", action) + require.Equal(t, "global_global", source) + + // maliciousUrl 降级为 block + action, source = config.ResolveRiskActionByType("", MaliciousUrlDataType) + require.Equal(t, "block", action) + require.Equal(t, "global_global", source) +} + +// TestTC_RESOLVE_002 全局 riskAction=mask + 全局 promptAttackAction=block +// => promptAttack 返回 block,sensitiveData 返回 mask +func TestTC_RESOLVE_002(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + PromptAttackAction: "block", + } + + // promptAttack 返回 block(source=global_dimension) + action, source := config.ResolveRiskActionByType("", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "global_dimension", source) + + // sensitiveData 返回 mask(source=global_global) 
+ action, source = config.ResolveRiskActionByType("", SensitiveDataType) + require.Equal(t, "mask", action) + require.Equal(t, "global_global", source) +} + +// TestTC_RESOLVE_003 consumer 规则含 riskAction=block,全局 sensitiveDataAction=mask +// => sensitiveData 返回 block(consumer_global 优先于 global_dimension) +func TestTC_RESOLVE_003(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + SensitiveDataAction: "mask", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Exact: "user-a"}, + "riskAction": "block", + }, + }, + } + + // consumer_global(block) 优先于 global_dimension(mask) + action, source := config.ResolveRiskActionByType("user-a", SensitiveDataType) + require.Equal(t, "block", action) + require.Equal(t, "consumer_global", source) + + // 未命中 consumer 规则时,回退到 global_dimension + action, source = config.ResolveRiskActionByType("user-b", SensitiveDataType) + require.Equal(t, "mask", action) + require.Equal(t, "global_dimension", source) +} + +// TestTC_RESOLVE_004 consumer 规则含 sensitiveDataAction=mask 且 riskAction=block +// => sensitiveData 返回 mask(consumer_dimension 优先) +func TestTC_RESOLVE_004(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Exact: "user-a"}, + "riskAction": "block", + "sensitiveDataAction": "mask", + }, + }, + } + + // consumer_dimension(mask) 优先于 consumer_global(block) + action, source := config.ResolveRiskActionByType("user-a", SensitiveDataType) + require.Equal(t, "mask", action) + require.Equal(t, "consumer_dimension", source) + + // promptAttack 无 consumer_dimension,回退到 consumer_global(block) + action, source = config.ResolveRiskActionByType("user-a", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "consumer_global", source) +} + +// TestTC_RESOLVE_005 都未配置 => 返回 block(source=default) +func TestTC_RESOLVE_005(t *testing.T) { + config := AISecurityConfig{} + + action, source := 
config.ResolveRiskActionByType("", SensitiveDataType) + require.Equal(t, "block", action) + require.Equal(t, "default", source) + + action, source = config.ResolveRiskActionByType("", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "default", source) +} + +// ============================================================================= +// TC-MATCH: first-match 语义测试(getMatchedConsumerRiskRule) +// ============================================================================= + +// TestTC_MATCH_001 两条规则都可命中(prefix + exact),prefix 在前 => 命中 prefix +func TestTC_MATCH_001(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Prefix: "user-"}, + "sensitiveDataAction": "mask", + }, + { + "matcher": Matcher{Exact: "user-a"}, + "sensitiveDataAction": "block", + }, + }, + } + + // "user-a" 同时匹配 prefix("user-") 和 exact("user-a"),但 prefix 在前 + action, source := config.ResolveRiskActionByType("user-a", SensitiveDataType) + require.Equal(t, "mask", action) + require.Equal(t, "consumer_dimension", source) +} + +// TestTC_MATCH_002 首条命中但未配置某维度动作,第二条配置了 => 不读取第二条,回退全局 +func TestTC_MATCH_002(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + PromptAttackAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Prefix: "user-"}, + "riskAction": "mask", + // 未配置 promptAttackAction + }, + { + "matcher": Matcher{Exact: "user-a"}, + "promptAttackAction": "block", + }, + }, + } + + // "user-a" 命中首条 prefix 规则,promptAttackAction 未配置 + // 回退到 consumer_global(mask),然后 enforceMaskBoundary 降级为 block + action, source := config.ResolveRiskActionByType("user-a", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "consumer_global", source) +} + +// TestTC_MATCH_003 无规则命中 => 回退全局 +func TestTC_MATCH_003(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + SensitiveDataAction: "mask", + ConsumerRiskLevel: 
[]map[string]interface{}{ + { + "matcher": Matcher{Exact: "vip-user"}, + "sensitiveDataAction": "block", + }, + }, + } + + // "other-user" 不匹配任何规则,回退到 global_dimension + action, source := config.ResolveRiskActionByType("other-user", SensitiveDataType) + require.Equal(t, "mask", action) + require.Equal(t, "global_dimension", source) + + // promptAttack 无 global_dimension,回退到 global_global(mask),降级为 block + action, source = config.ResolveRiskActionByType("other-user", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "global_global", source) +} + +// ============================================================================= +// 补充边界测试 +// ============================================================================= + +// TestTC_RESOLVE_006 consumer 规则中 promptAttackAction=mask => 降级为 block +func TestTC_RESOLVE_006(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Exact: "user-a"}, + "promptAttackAction": "mask", // 非 sensitiveData 维度配置 mask + }, + }, + } + + // consumer_dimension(mask) 降级为 block + action, source := config.ResolveRiskActionByType("user-a", PromptAttackType) + require.Equal(t, "block", action) + require.Equal(t, "consumer_dimension", source) +} + +// TestTC_RESOLVE_007 consumer 规则中 contentModerationAction=mask => 降级为 block +func TestTC_RESOLVE_007(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Exact: "user-a"}, + "contentModerationAction": "mask", + }, + }, + } + + action, source := config.ResolveRiskActionByType("user-a", ContentModerationType) + require.Equal(t, "block", action) + require.Equal(t, "consumer_dimension", source) +} + +// TestTC_RESOLVE_008 consumer 规则中 maliciousUrlAction=mask => 降级为 block +func TestTC_RESOLVE_008(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + 
{ + "matcher": Matcher{Exact: "user-a"}, + "maliciousUrlAction": "mask", + }, + }, + } + + action, source := config.ResolveRiskActionByType("user-a", MaliciousUrlDataType) + require.Equal(t, "block", action) + require.Equal(t, "consumer_dimension", source) +} + +// TestTC_RESOLVE_009 未知 detailType => dimensionActionKey 返回空,跳过 consumer_dimension +func TestTC_RESOLVE_009(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Exact: "user-a"}, + "riskAction": "mask", + }, + }, + } + + // 未知 Type,dimKey 为空,跳过 consumer_dimension,回退到 consumer_global(mask) + // 非 sensitiveData 维度的 mask 降级为 block + action, source := config.ResolveRiskActionByType("user-a", "unknownType") + require.Equal(t, "block", action) + require.Equal(t, "consumer_global", source) +} + +// TestTC_RESOLVE_010 未知 detailType + 无 consumer 匹配 => 回退到 global_global +func TestTC_RESOLVE_010(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + } + + // 未知 Type,无 consumer 匹配,回退到 global_global(mask) + // 非 sensitiveData 维度的 mask 降级为 block + action, source := config.ResolveRiskActionByType("", "unknownType") + require.Equal(t, "block", action) + require.Equal(t, "global_global", source) +} + +// TestTC_RESOLVE_011 所有 6 个维度的 global dimension action 正确映射 +func TestTC_RESOLVE_011(t *testing.T) { + config := AISecurityConfig{ + ContentModerationAction: "block", + PromptAttackAction: "block", + SensitiveDataAction: "mask", + MaliciousUrlAction: "block", + ModelHallucinationAction: "block", + CustomLabelAction: "block", + } + + tests := []struct { + detailType string + expectedAction string + expectedSource string + }{ + {ContentModerationType, "block", "global_dimension"}, + {PromptAttackType, "block", "global_dimension"}, + {SensitiveDataType, "mask", "global_dimension"}, + {MaliciousUrlDataType, "block", "global_dimension"}, + {ModelHallucinationDataType, "block", "global_dimension"}, + {CustomLabelType, "block", 
"global_dimension"}, + } + + for _, tt := range tests { + action, source := config.ResolveRiskActionByType("", tt.detailType) + require.Equal(t, tt.expectedAction, action, "detailType=%s", tt.detailType) + require.Equal(t, tt.expectedSource, source, "detailType=%s", tt.detailType) + } +} + +// TestTC_MATCH_004 空 consumer 不匹配 exact/prefix 规则 => 回退全局 +func TestTC_MATCH_004(t *testing.T) { + config := AISecurityConfig{ + RiskAction: "mask", + SensitiveDataAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": Matcher{Exact: "vip"}, + "sensitiveDataAction": "mask", + }, + }, + } + + // 空 consumer 不匹配 exact("vip"),回退到 global_dimension(block) + action, source := config.ResolveRiskActionByType("", SensitiveDataType) + require.Equal(t, "block", action) + require.Equal(t, "global_dimension", source) +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/config.go b/plugins/wasm-go/extensions/ai-security-guard/config/config.go index 175805a14..7cc8f865a 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/config/config.go +++ b/plugins/wasm-go/extensions/ai-security-guard/config/config.go @@ -30,6 +30,9 @@ const ( SensitiveDataType = "sensitiveData" MaliciousUrlDataType = "maliciousUrl" ModelHallucinationDataType = "modelHallucination" + CustomLabelType = "customLabel" + MaliciousFileType = "maliciousFile" + WaterMarkType = "waterMark" // Default configurations OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` @@ -89,16 +92,23 @@ type Response struct { type Data struct { RiskLevel string `json:"RiskLevel,omitempty"` AttackLevel string `json:"AttackLevel,omitempty"` + Suggestion string `json:"Suggestion,omitempty"` Result []Result `json:"Result,omitempty"` Advice []Advice `json:"Advice,omitempty"` Detail []Detail 
`json:"Detail,omitempty"` } +type Ext struct { + Desensitization string `json:"Desensitization,omitempty"` + SensitiveData []string `json:"SensitiveData,omitempty"` +} + type Result struct { RiskWords string `json:"RiskWords,omitempty"` Description string `json:"Description,omitempty"` Confidence float64 `json:"Confidence,omitempty"` Label string `json:"Label,omitempty"` + Ext Ext `json:"Ext,omitempty"` } type Advice struct { @@ -108,9 +118,10 @@ type Advice struct { } type Detail struct { - Suggestion string `json:"Suggestion,omitempty"` - Type string `json:"Type,omitempty"` - Level string `json:"Level,omitempty"` + Suggestion string `json:"Suggestion,omitempty"` + Type string `json:"Type,omitempty"` + Level string `json:"Level,omitempty"` + Result []Result `json:"Result,omitempty"` } type Matcher struct { @@ -157,6 +168,7 @@ type AISecurityConfig struct { SensitiveDataLevelBar string MaliciousUrlLevelBar string ModelHallucinationLevelBar string + CustomLabelLevelBar string Timeout uint32 BufferLimit int Metrics map[string]proxywasm.MetricCounter @@ -167,6 +179,15 @@ type AISecurityConfig struct { ApiType string // openai, qwen, comfyui, etc. 
ProviderType string + // "block" or "mask", default "block" + RiskAction string + // Dimension-level action fields (optional, empty string means not configured) + ContentModerationAction string + PromptAttackAction string + SensitiveDataAction string + MaliciousUrlAction string + ModelHallucinationAction string + CustomLabelAction string } func (config *AISecurityConfig) Parse(json gjson.Result) error { @@ -191,6 +212,48 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { } // set default values config.SetDefaultValues() + // set riskAction + if obj := json.Get("riskAction"); obj.Exists() { + config.RiskAction = obj.String() + if config.RiskAction != "block" && config.RiskAction != "mask" { + return errors.New("invalid riskAction, value must be one of [block, mask]") + } + } + // parse global dimension action fields + isMultiModalGuard := config.Action == MultiModalGuard || config.Action == MultiModalGuardForBase64 + dimensionActionFields := []struct { + fieldName string + target *string + }{ + {"contentModerationAction", &config.ContentModerationAction}, + {"promptAttackAction", &config.PromptAttackAction}, + {"sensitiveDataAction", &config.SensitiveDataAction}, + {"maliciousUrlAction", &config.MaliciousUrlAction}, + {"modelHallucinationAction", &config.ModelHallucinationAction}, + {"customLabelAction", &config.CustomLabelAction}, + } + hasDimensionAction := false + for _, field := range dimensionActionFields { + if isMultiModalGuard { + val, err := parseDimensionAction(json, field.fieldName) + if err != nil { + return err + } + *field.target = val + if val != "" { + hasDimensionAction = true + } + } else { + // Non-MultiModalGuard: read value without validation, field will be ignored at runtime + if obj := json.Get(field.fieldName); obj.Exists() { + *field.target = obj.String() + hasDimensionAction = true + } + } + } + if hasDimensionAction && !isMultiModalGuard { + proxywasm.LogWarnf("dimension action fields are configured but will be ignored 
because action is %s (not MultiModalGuard/MultiModalGuardForBase64)", config.Action) + } // set values if obj := json.Get("riskLevelBar"); obj.Exists() { config.RiskLevelBar = obj.String() @@ -254,6 +317,12 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]") } } + if obj := json.Get("customLabelLevelBar"); obj.Exists() { + config.CustomLabelLevelBar = obj.String() + if LevelToInt(config.CustomLabelLevelBar) <= 0 { + return errors.New("invalid customLabelLevelBar, value must be one of [max, high, medium, low]") + } + } if obj := json.Get("timeout"); obj.Exists() { config.Timeout = uint32(obj.Int()) } @@ -323,6 +392,31 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { case "regexp": m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} } + if ra, ok := m["riskAction"]; ok { + raStr := fmt.Sprint(ra) + if raStr != "block" && raStr != "mask" { + return errors.New("invalid riskAction in consumerRiskLevel, value must be one of [block, mask]") + } + } + // Validate dimension action fields in consumer risk level + if isMultiModalGuard { + consumerDimensionActionFields := []string{ + "contentModerationAction", + "promptAttackAction", + "sensitiveDataAction", + "maliciousUrlAction", + "modelHallucinationAction", + "customLabelAction", + } + for _, fieldName := range consumerDimensionActionFields { + if v, ok := m[fieldName]; ok { + vStr := fmt.Sprint(v) + if vStr != "block" && vStr != "mask" { + return fmt.Errorf("invalid %s in consumerRiskLevel, value must be one of [block, mask]", fieldName) + } + } + } + } config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, m) } } @@ -341,6 +435,19 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { return nil } +// parseDimensionAction parses a dimension action field from JSON config. 
+// Returns the value if valid (block/mask), empty string if not present, or error if invalid. +func parseDimensionAction(json gjson.Result, fieldName string) (string, error) { + if obj := json.Get(fieldName); obj.Exists() { + val := obj.String() + if val != "block" && val != "mask" { + return "", fmt.Errorf("invalid %s, value must be one of [block, mask]", fieldName) + } + return val, nil + } + return "", nil +} + func (config *AISecurityConfig) SetDefaultValues() { switch config.Action { case TextModerationPlus: @@ -361,10 +468,12 @@ func (config *AISecurityConfig) SetDefaultValues() { config.SensitiveDataLevelBar = S4Sensitive config.ModelHallucinationLevelBar = MaxRisk config.MaliciousUrlLevelBar = MaxRisk + config.CustomLabelLevelBar = MaxRisk config.Timeout = DefaultTimeout config.BufferLimit = 1000 config.ApiType = ApiTextGeneration config.ProviderType = ProviderOpenAI + config.RiskAction = "block" } func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) { @@ -436,31 +545,35 @@ func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) st return result } -func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string { - result := config.RiskLevelBar +// getMatchedConsumerRiskRule returns the first matched consumer rule using first-match semantics. +// It iterates ConsumerRiskLevel in order and returns the first rule whose matcher matches the consumer. +// Returns nil, false if no rule matches. 
+func (config *AISecurityConfig) getMatchedConsumerRiskRule(consumer string) (map[string]interface{}, bool) { for _, obj := range config.ConsumerRiskLevel { if matcher, ok := obj["matcher"].(Matcher); ok { if matcher.match(consumer) { - if riskLevelBar, ok := obj["riskLevelBar"]; ok { - result, _ = riskLevelBar.(string) - } - break + return obj, true } } } + return nil, false +} + +func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string { + result := config.RiskLevelBar + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if riskLevelBar, ok := rule["riskLevelBar"]; ok { + result, _ = riskLevelBar.(string) + } + } return result } func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string { result := config.ContentModerationLevelBar - for _, obj := range config.ConsumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok { - result, _ = contentModerationLevelBar.(string) - } - break - } + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if contentModerationLevelBar, ok := rule["contentModerationLevelBar"]; ok { + result, _ = contentModerationLevelBar.(string) } } return result @@ -468,14 +581,9 @@ func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) st func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string { result := config.PromptAttackLevelBar - for _, obj := range config.ConsumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok { - result, _ = promptAttackLevelBar.(string) - } - break - } + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if promptAttackLevelBar, ok := rule["promptAttackLevelBar"]; ok { + result, _ = promptAttackLevelBar.(string) } } return result @@ -483,14 +591,9 @@ func (config 
*AISecurityConfig) GetPromptAttackLevelBar(consumer string) string func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string { result := config.SensitiveDataLevelBar - for _, obj := range config.ConsumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok { - result, _ = sensitiveDataLevelBar.(string) - } - break - } + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if sensitiveDataLevelBar, ok := rule["sensitiveDataLevelBar"]; ok { + result, _ = sensitiveDataLevelBar.(string) } } return result @@ -498,14 +601,9 @@ func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string { result := config.MaliciousUrlLevelBar - for _, obj := range config.ConsumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok { - result, _ = maliciousUrlLevelBar.(string) - } - break - } + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if maliciousUrlLevelBar, ok := rule["maliciousUrlLevelBar"]; ok { + result, _ = maliciousUrlLevelBar.(string) } } return result @@ -513,19 +611,124 @@ func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string { result := config.ModelHallucinationLevelBar - for _, obj := range config.ConsumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok { - result, _ = modelHallucinationLevelBar.(string) - } - break - } + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if modelHallucinationLevelBar, ok := rule["modelHallucinationLevelBar"]; ok { + result, _ = 
modelHallucinationLevelBar.(string) } } return result } +func (config *AISecurityConfig) GetCustomLabelLevelBar(consumer string) string { + result := config.CustomLabelLevelBar + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if customLabelLevelBar, ok := rule["customLabelLevelBar"]; ok { + result, _ = customLabelLevelBar.(string) + } + } + return result +} + +func (config *AISecurityConfig) GetRiskAction(consumer string) string { + result := config.RiskAction + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + if riskAction, ok := rule["riskAction"]; ok { + result, _ = riskAction.(string) + } + } + return result +} + +// dimensionActionKey maps a detailType to the corresponding key used in consumerRiskLevel map. +// For example, SensitiveDataType -> "sensitiveDataAction". +func dimensionActionKey(detailType string) string { + switch detailType { + case ContentModerationType: + return "contentModerationAction" + case PromptAttackType: + return "promptAttackAction" + case SensitiveDataType: + return "sensitiveDataAction" + case MaliciousUrlDataType: + return "maliciousUrlAction" + case ModelHallucinationDataType: + return "modelHallucinationAction" + case CustomLabelType: + return "customLabelAction" + default: + return "" + } +} + +// getGlobalDimensionAction returns the global dimension action field value for the given detailType. 
+func (config *AISecurityConfig) getGlobalDimensionAction(detailType string) string { + switch detailType { + case ContentModerationType: + return config.ContentModerationAction + case PromptAttackType: + return config.PromptAttackAction + case SensitiveDataType: + return config.SensitiveDataAction + case MaliciousUrlDataType: + return config.MaliciousUrlAction + case ModelHallucinationDataType: + return config.ModelHallucinationAction + case CustomLabelType: + return config.CustomLabelAction + default: + return "" + } +} + +// enforceMaskBoundary downgrades mask to block for non-sensitiveData dimensions, +// since only sensitiveData supports actual mask/desensitization. +func enforceMaskBoundary(action, detailType, source string) (string, string) { + if action == "mask" && detailType != SensitiveDataType { + proxywasm.LogWarnf("mask action not supported for dimension %s, downgrading to block", detailType) + return "block", source + } + return action, source +} + +// ResolveRiskActionByType resolves the final action for a given dimension type +// using 5-level priority: consumer_dimension > consumer_global > global_dimension > global_global > default(block). +// Returns (action, source) where source indicates which priority level the action came from. +func (config *AISecurityConfig) ResolveRiskActionByType(consumer string, detailType string) (string, string) { + dimKey := dimensionActionKey(detailType) + + // 1. Check matched consumer rule + if rule, ok := config.getMatchedConsumerRiskRule(consumer); ok { + // 1a. consumer dimension action + if dimKey != "" { + if v, exists := rule[dimKey]; exists { + if s, ok := v.(string); ok && s != "" { + return enforceMaskBoundary(s, detailType, "consumer_dimension") + } + } + } + // 1b. consumer global riskAction + if v, exists := rule["riskAction"]; exists { + if s, ok := v.(string); ok && s != "" { + return enforceMaskBoundary(s, detailType, "consumer_global") + } + } + } + + // 2. 
Global dimension action + globalDimAction := config.getGlobalDimensionAction(detailType) + if globalDimAction != "" { + return enforceMaskBoundary(globalDimAction, detailType, "global_dimension") + } + + // 3. Global riskAction + if config.RiskAction != "" { + return enforceMaskBoundary(config.RiskAction, detailType, "global_global") + } + + // 4. Default block + return "block", "default" +} + func LevelToInt(riskLevel string) int { // First check against our defined constants switch strings.ToLower(riskLevel) { @@ -544,61 +747,157 @@ func LevelToInt(riskLevel string) int { } } -func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool { - if action == MultiModalGuard || action == MultiModalGuardForBase64 { - // Check top-level risk levels for MultiModalGuard - if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) { - return false - } - // Also check AttackLevel for prompt attack detection - if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) { - return false - } +type RiskResult int - // Check detailed results for backward compatibility - for _, detail := range data.Detail { - switch detail.Type { - case ContentModerationType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) { - return false - } - case PromptAttackType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) { - return false - } - case SensitiveDataType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) { - return false - } - case MaliciousUrlDataType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) { - return false - } - case ModelHallucinationDataType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) { - return false - } +const ( + RiskPass RiskResult = iota // 放行 + RiskMask // 需要脱敏 + RiskBlock 
// 需要拦截 +) + +// EvaluateRisk evaluates the risk of the given data and returns a RiskResult. +// For MultiModalGuard/MultiModalGuardForBase64, it uses the unified per-dimension +// action resolution flow (evaluateRiskMultiModal). +// For other actions (e.g. TextModerationPlus), it only checks RiskLevelBar. +func EvaluateRisk(action string, data Data, config AISecurityConfig, consumer string) RiskResult { + if action == MultiModalGuard || action == MultiModalGuardForBase64 { + return evaluateRiskMultiModal(data, config, consumer) + } + // TextModerationPlus and other non-MultiModalGuard actions: dimension actions not used + if LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer)) { + return RiskPass + } + return RiskBlock +} + +// evaluateRiskMultiModal implements the unified per-dimension risk evaluation for MultiModalGuard. +// It follows the design doc section 11.1-7 pseudocode: +// 1. Top-level compatibility gate (RiskLevel / AttackLevel) +// 2. Per-Detail dimension action resolution and threshold check +// 3. Data.Suggestion=block fallback +func evaluateRiskMultiModal(data Data, config AISecurityConfig, consumer string) RiskResult { + // 1. Top-level compatibility gate + if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) { + return RiskBlock + } + if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) { + return RiskBlock + } + + // 2. 
Detail per-dimension evaluation + hasMask := false + for _, detail := range data.Detail { + dimAction, actionSource := config.ResolveRiskActionByType(consumer, detail.Type) + exceeds := detailExceedsThreshold(detail, config, consumer) + + proxywasm.LogInfof("safecheck_risk_type=%s, safecheck_resolved_action=%s, safecheck_action_source=%s", + detail.Type, dimAction, actionSource) + + if detailTriggersBlock(detail, dimAction, exceeds) { + return RiskBlock + } + // dimAction == "mask" (only sensitiveData effective; others already downgraded by enforceMaskBoundary) + if dimAction == "mask" && detail.Suggestion == "mask" { + if exceeds { + hasMask = true + } else { + proxywasm.LogInfof("safecheck_mask_skipped: type=%s, suggestion=%s, level=%s, threshold=%s", + detail.Type, detail.Suggestion, detail.Level, config.GetSensitiveDataLevelBar(consumer)) } } + } + + // 3. Data.Suggestion=block fallback + if data.Suggestion == "block" { + return RiskBlock + } + + if hasMask { + return RiskMask + } + return RiskPass +} + +// detailTriggersBlock returns whether this single detail should trigger blocking, +// given the resolved dimension action and threshold evaluation result. +func detailTriggersBlock(detail Detail, dimAction string, exceeds bool) bool { + if detail.Suggestion == "block" { return true - } else { - return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer)) + } + if dimAction == "block" { + return exceeds + } + // dimAction == "mask": explicit mask suggestion is allowed to pass for desensitization. + if detail.Suggestion == "mask" { + return false + } + return exceeds +} + +// detailExceedsThreshold checks if a single Detail's level exceeds the configured threshold +// for its Type. 
+func detailExceedsThreshold(detail Detail, config AISecurityConfig, consumer string) bool { + switch detail.Type { + case ContentModerationType: + return LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) + case PromptAttackType: + return LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) + case SensitiveDataType: + return LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) + case MaliciousUrlDataType: + return LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) + case ModelHallucinationDataType: + return LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) + case CustomLabelType: + return LevelToInt(detail.Level) >= LevelToInt(config.GetCustomLabelLevelBar(consumer)) + default: + return false } } +func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool { + return EvaluateRisk(action, data, config, consumer) != RiskBlock +} + +// ExtractDesensitization extracts the desensitization content from the first Detail +// with Type=sensitiveData and Suggestion=mask. Returns empty string if no such +// Detail exists, if the Detail has no Result entries, or if the desensitization +// content is empty. +func ExtractDesensitization(data Data) string { + for _, detail := range data.Detail { + if detail.Type == SensitiveDataType && detail.Suggestion == "mask" { + if len(detail.Result) > 0 && detail.Result[0].Ext.Desensitization != "" { + return detail.Result[0].Ext.Desensitization + } + } + } + return "" +} + +type BlockedDetail struct { + Type string `json:"type"` + Level string `json:"level"` +} + type DenyResponseBody struct { - BlockedDetails []Detail `json:"blockedDetails"` - RequestId string `json:"requestId"` - // GuardCode is the business code returned by the security service (typically 200 when the check - // succeeded and a risk was detected). 
It is NOT an HTTP status code. - GuardCode int `json:"guardCode"` + Code int `json:"code"` + DenyMessage string `json:"denyMessage,omitempty"` + BlockedDetails []BlockedDetail `json:"blockedDetails"` } func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer string) ([]byte, error) { + details := GetUnacceptableDetail(response.Data, config, consumer) + blocked := make([]BlockedDetail, 0, len(details)) + for _, d := range details { + blocked = append(blocked, BlockedDetail{ + Type: d.Type, + Level: d.Level, + }) + } body := DenyResponseBody{ - BlockedDetails: GetUnacceptableDetail(response.Data, config, consumer), - RequestId: response.RequestId, - GuardCode: response.Code, + Code: response.Code, + DenyMessage: config.DenyMessage, + BlockedDetails: blocked, } return json.Marshal(body) } @@ -606,27 +905,10 @@ func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer func GetUnacceptableDetail(data Data, config AISecurityConfig, consumer string) []Detail { result := []Detail{} for _, detail := range data.Detail { - switch detail.Type { - case ContentModerationType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) { - result = append(result, detail) - } - case PromptAttackType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) { - result = append(result, detail) - } - case SensitiveDataType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) { - result = append(result, detail) - } - case MaliciousUrlDataType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) { - result = append(result, detail) - } - case ModelHallucinationDataType: - if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) { - result = append(result, detail) - } + dimAction, _ := config.ResolveRiskActionByType(consumer, detail.Type) + exceeds := detailExceedsThreshold(detail, 
config, consumer) + if detailTriggersBlock(detail, dimAction, exceeds) { + result = append(result, detail) } } // Fallback: when the security service returns a top-level risk signal but no Detail entries, diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_property_test.go b/plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_property_test.go new file mode 100644 index 000000000..10d3ef98c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_property_test.go @@ -0,0 +1,648 @@ +// Copyright (c) 2024 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "math/rand" + "testing" + "testing/quick" +) + +// validSensitiveLevels are the valid sensitive data levels in ascending order. +var validSensitiveLevels = []string{"S0", "S1", "S2", "S3", "S4"} + +// Feature: sensitive-data-mask-threshold, Property 1: Above-threshold mask produces RiskMask +// **Validates: Requirements 1.1, 4.1** +// +// For any valid sensitive level L and threshold T where LevelToInt(L) >= LevelToInt(T), +// when evaluateRiskMultiModal is called with a single Detail of Type=sensitiveData, +// Suggestion=mask, Level=L, config SensitiveDataAction=mask, SensitiveDataLevelBar=T, +// and no other blocking conditions, the result SHALL be RiskMask. 
+func TestProperty1_AboveThresholdMaskProducesRiskMask(t *testing.T) { + f := func(seed uint64) bool { + // Use seed to deterministically pick a (level, threshold) pair + // where LevelToInt(level) >= LevelToInt(threshold) + r := rand.New(rand.NewSource(int64(seed))) + + // Pick threshold index [0..4], then level index [thresholdIdx..4] + thresholdIdx := r.Intn(len(validSensitiveLevels)) + levelIdx := thresholdIdx + r.Intn(len(validSensitiveLevels)-thresholdIdx) + + level := validSensitiveLevels[levelIdx] + threshold := validSensitiveLevels[thresholdIdx] + + // Sanity: level >= threshold + if LevelToInt(level) < LevelToInt(threshold) { + t.Errorf("generator bug: level=%s (%d) < threshold=%s (%d)", level, LevelToInt(level), threshold, LevelToInt(threshold)) + return false + } + + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = threshold + // Set all other thresholds to max (most permissive) to avoid interference + config.ContentModerationLevelBar = MaxRisk + config.PromptAttackLevelBar = MaxRisk + config.MaliciousUrlLevelBar = MaxRisk + config.ModelHallucinationLevelBar = MaxRisk + config.CustomLabelLevelBar = MaxRisk + config.RiskAction = "block" + + data := Data{ + RiskLevel: "none", // Avoid top-level gate triggering + Detail: []Detail{ + { + Type: SensitiveDataType, + Suggestion: "mask", + Level: level, + Result: []Result{{Ext: Ext{Desensitization: "masked-content"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + if result != RiskMask { + t.Errorf("expected RiskMask for level=%s, threshold=%s, got %d", level, threshold, result) + return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 1 failed: %v", err) + fmt.Printf("Property 1 counterexample: %v\n", err) + } +} + +// Feature: sensitive-data-mask-threshold, Property 2: Below-threshold mask produces RiskPass +// **Validates: Requirements 1.2, 1.3** 
+// +// For any valid sensitive level L and threshold T where LevelToInt(L) < LevelToInt(T), +// when evaluateRiskMultiModal is called with a single Detail of Type=sensitiveData, +// Suggestion=mask, Level=L, config SensitiveDataAction=mask, SensitiveDataLevelBar=T, +// and no other blocking conditions, the result SHALL be RiskPass. +func TestProperty2_BelowThresholdMaskProducesRiskPass(t *testing.T) { + f := func(seed uint64) bool { + // Use seed to deterministically pick a (level, threshold) pair + // where LevelToInt(level) < LevelToInt(threshold) + r := rand.New(rand.NewSource(int64(seed))) + + // Pick threshold index [1..4], then level index [0..thresholdIdx-1] + thresholdIdx := 1 + r.Intn(len(validSensitiveLevels)-1) // [1..4] + levelIdx := r.Intn(thresholdIdx) // [0..thresholdIdx-1] + + level := validSensitiveLevels[levelIdx] + threshold := validSensitiveLevels[thresholdIdx] + + // Sanity: level < threshold + if LevelToInt(level) >= LevelToInt(threshold) { + t.Errorf("generator bug: level=%s (%d) >= threshold=%s (%d)", level, LevelToInt(level), threshold, LevelToInt(threshold)) + return false + } + + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = threshold + config.ContentModerationLevelBar = MaxRisk + config.PromptAttackLevelBar = MaxRisk + config.MaliciousUrlLevelBar = MaxRisk + config.ModelHallucinationLevelBar = MaxRisk + config.CustomLabelLevelBar = MaxRisk + config.RiskAction = "block" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Type: SensitiveDataType, + Suggestion: "mask", + Level: level, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + if result != RiskPass { + t.Errorf("expected RiskPass for level=%s, threshold=%s, got %d", level, threshold, result) + return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 2 failed: %v", err) + fmt.Printf("Property 2 counterexample: 
%v\n", err) + } +} + +// Feature: sensitive-data-mask-threshold, Property 3: Per-detail threshold independence +// **Validates: Requirements 1.4** +// +// For any list of sensitiveData Details each with Suggestion=mask and varying levels, +// and a threshold T, when evaluateRiskMultiModal is called with SensitiveDataAction=mask +// and no blocking conditions: the result SHALL be RiskMask if and only if at least one +// Detail has LevelToInt(Level) >= LevelToInt(T). +func TestProperty3_PerDetailThresholdIndependence(t *testing.T) { + f := func(seed uint64) bool { + r := rand.New(rand.NewSource(int64(seed))) + + // Pick a random threshold from validSensitiveLevels + thresholdIdx := r.Intn(len(validSensitiveLevels)) + threshold := validSensitiveLevels[thresholdIdx] + + // Generate 1-5 random sensitiveData details + numDetails := 1 + r.Intn(5) + details := make([]Detail, numDetails) + expectMask := false + + for i := 0; i < numDetails; i++ { + levelIdx := r.Intn(len(validSensitiveLevels)) + level := validSensitiveLevels[levelIdx] + + detail := Detail{ + Type: SensitiveDataType, + Suggestion: "mask", + Level: level, + } + + // Details that meet threshold should have Result with Desensitization content + if LevelToInt(level) >= LevelToInt(threshold) { + expectMask = true + detail.Result = []Result{{Ext: Ext{Desensitization: "masked-content"}}} + } + + details[i] = detail + } + + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = threshold + // Set all other thresholds to max to avoid interference + config.ContentModerationLevelBar = MaxRisk + config.PromptAttackLevelBar = MaxRisk + config.MaliciousUrlLevelBar = MaxRisk + config.ModelHallucinationLevelBar = MaxRisk + config.CustomLabelLevelBar = MaxRisk + config.RiskAction = "block" + + data := Data{ + RiskLevel: "none", + Detail: details, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + + if expectMask { + if result != RiskMask { + t.Errorf("expected RiskMask: 
threshold=%s, details=%v, got %d", threshold, describeLevels(details), result) + return false + } + } else { + if result != RiskPass { + t.Errorf("expected RiskPass: threshold=%s, details=%v, got %d", threshold, describeLevels(details), result) + return false + } + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 3 failed: %v", err) + fmt.Printf("Property 3 counterexample: %v\n", err) + } +} + +// describeLevels returns a slice of level strings from the given details for error reporting. +func describeLevels(details []Detail) []string { + levels := make([]string, len(details)) + for i, d := range details { + levels[i] = d.Level + } + return levels +} + +// validGeneralRiskLevels are the valid general risk levels in ascending order. +var validGeneralRiskLevels = []string{"none", "low", "medium", "high", "max"} + +// knownDetailTypes are the known dimension types used for generating random details. +var knownDetailTypes = []string{ + SensitiveDataType, + ContentModerationType, + PromptAttackType, + MaliciousUrlDataType, + ModelHallucinationDataType, + CustomLabelType, +} + +// Feature: sensitive-data-mask-threshold, Property 4: Block triggers always produce RiskBlock +// **Validates: Requirements 3.1, 3.2** +// +// Sub-property 4a: For any Detail with Suggestion=block, regardless of type, level, +// dimAction, or threshold configuration, evaluateRiskMultiModal SHALL return RiskBlock. +// +// Sub-property 4b: For any Detail where the resolved dimAction is "block" and the +// detail's level exceeds the configured threshold, evaluateRiskMultiModal SHALL return RiskBlock. 
+func TestProperty4a_SuggestionBlockAlwaysProducesRiskBlock(t *testing.T) { + f := func(seed uint64) bool { + r := rand.New(rand.NewSource(int64(seed))) + + // Pick a random detail type + detailType := knownDetailTypes[r.Intn(len(knownDetailTypes))] + + // Pick a random level based on type + var level string + if detailType == SensitiveDataType { + level = validSensitiveLevels[r.Intn(len(validSensitiveLevels))] + } else { + level = validGeneralRiskLevels[r.Intn(len(validGeneralRiskLevels))] + } + + // Random config: pick random dimAction (block or mask) and random thresholds + config := baseConfig() + + // Randomly assign dimension actions + actions := []string{"block", "mask"} + config.ContentModerationAction = actions[r.Intn(2)] + config.PromptAttackAction = actions[r.Intn(2)] + config.SensitiveDataAction = actions[r.Intn(2)] + config.MaliciousUrlAction = actions[r.Intn(2)] + config.ModelHallucinationAction = actions[r.Intn(2)] + config.CustomLabelAction = actions[r.Intn(2)] + + // Random thresholds + config.ContentModerationLevelBar = validGeneralRiskLevels[1+r.Intn(len(validGeneralRiskLevels)-1)] + config.PromptAttackLevelBar = validGeneralRiskLevels[1+r.Intn(len(validGeneralRiskLevels)-1)] + config.SensitiveDataLevelBar = validSensitiveLevels[r.Intn(len(validSensitiveLevels))] + config.MaliciousUrlLevelBar = validGeneralRiskLevels[1+r.Intn(len(validGeneralRiskLevels)-1)] + config.ModelHallucinationLevelBar = validGeneralRiskLevels[1+r.Intn(len(validGeneralRiskLevels)-1)] + config.CustomLabelLevelBar = validGeneralRiskLevels[1+r.Intn(len(validGeneralRiskLevels)-1)] + + data := Data{ + RiskLevel: "none", // Avoid top-level gate interference + Detail: []Detail{ + { + Type: detailType, + Suggestion: "block", // Always block suggestion + Level: level, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + if result != RiskBlock { + t.Errorf("expected RiskBlock for Suggestion=block, type=%s, level=%s, got %d", detailType, level, result) + 
return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 4a failed: %v", err) + fmt.Printf("Property 4a counterexample: %v\n", err) + } +} + +func TestProperty4b_DimActionBlockExceedsThresholdProducesRiskBlock(t *testing.T) { + // Test with dimension types that support block action and have configurable thresholds. + // For each iteration, pick a dimension type, set its action to "block", + // and set level >= threshold to ensure exceeds=true. + type dimConfig struct { + detailType string + levels []string + setThreshold func(config *AISecurityConfig, threshold string) + } + + dims := []dimConfig{ + { + detailType: ContentModerationType, + levels: validGeneralRiskLevels, + setThreshold: func(c *AISecurityConfig, t string) { + c.ContentModerationAction = "block" + c.ContentModerationLevelBar = t + }, + }, + { + detailType: PromptAttackType, + levels: validGeneralRiskLevels, + setThreshold: func(c *AISecurityConfig, t string) { + c.PromptAttackAction = "block" + c.PromptAttackLevelBar = t + }, + }, + { + detailType: SensitiveDataType, + levels: validSensitiveLevels, + setThreshold: func(c *AISecurityConfig, t string) { + c.SensitiveDataAction = "block" + c.SensitiveDataLevelBar = t + }, + }, + { + detailType: MaliciousUrlDataType, + levels: validGeneralRiskLevels, + setThreshold: func(c *AISecurityConfig, t string) { + c.MaliciousUrlAction = "block" + c.MaliciousUrlLevelBar = t + }, + }, + { + detailType: ModelHallucinationDataType, + levels: validGeneralRiskLevels, + setThreshold: func(c *AISecurityConfig, t string) { + c.ModelHallucinationAction = "block" + c.ModelHallucinationLevelBar = t + }, + }, + { + detailType: CustomLabelType, + levels: validGeneralRiskLevels, + setThreshold: func(c *AISecurityConfig, t string) { + c.CustomLabelAction = "block" + c.CustomLabelLevelBar = t + }, + }, + } + + f := func(seed uint64) bool { + r := rand.New(rand.NewSource(int64(seed))) + + // Pick 
a random dimension + dim := dims[r.Intn(len(dims))] + + // Pick threshold index, then level index >= threshold + thresholdIdx := r.Intn(len(dim.levels)) + levelIdx := thresholdIdx + r.Intn(len(dim.levels)-thresholdIdx) + + threshold := dim.levels[thresholdIdx] + level := dim.levels[levelIdx] + + // Sanity: level >= threshold + if LevelToInt(level) < LevelToInt(threshold) { + t.Errorf("generator bug: level=%s (%d) < threshold=%s (%d)", level, LevelToInt(level), threshold, LevelToInt(threshold)) + return false + } + + config := baseConfig() + // Set all other thresholds to max to avoid interference + config.ContentModerationLevelBar = MaxRisk + config.PromptAttackLevelBar = MaxRisk + config.SensitiveDataLevelBar = S4Sensitive + config.MaliciousUrlLevelBar = MaxRisk + config.ModelHallucinationLevelBar = MaxRisk + config.CustomLabelLevelBar = MaxRisk + + // Configure the chosen dimension with block action and threshold + dim.setThreshold(&config, threshold) + + // Use a non-block suggestion so we test the dimAction=block + exceeds path + // (not the Suggestion=block shortcut tested in 4a) + suggestion := "pass" + + data := Data{ + RiskLevel: "none", // Avoid top-level gate interference + Detail: []Detail{ + { + Type: dim.detailType, + Suggestion: suggestion, + Level: level, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + if result != RiskBlock { + t.Errorf("expected RiskBlock for dimAction=block, type=%s, level=%s, threshold=%s, got %d", + dim.detailType, level, threshold, result) + return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 4b failed: %v", err) + fmt.Printf("Property 4b counterexample: %v\n", err) + } +} + +// Feature: sensitive-data-mask-threshold, Property 5: Top-level gates produce RiskBlock +// **Validates: Requirements 3.3, 3.4** +// +// Sub-property 5a: For any Data.RiskLevel and contentModerationLevelBar where +// 
LevelToInt(RiskLevel) >= LevelToInt(contentModerationLevelBar), +// evaluateRiskMultiModal SHALL return RiskBlock regardless of Detail content. +// +// Sub-property 5b: For any Data.AttackLevel and promptAttackLevelBar where +// LevelToInt(AttackLevel) >= LevelToInt(promptAttackLevelBar), +// evaluateRiskMultiModal SHALL return RiskBlock regardless of Detail content. +func TestProperty5a_TopLevelRiskLevelGateProducesRiskBlock(t *testing.T) { + f := func(seed uint64) bool { + r := rand.New(rand.NewSource(int64(seed))) + + // Pick (riskLevel, threshold) where LevelToInt(riskLevel) >= LevelToInt(threshold) + // Use validGeneralRiskLevels [none, low, medium, high, max] + thresholdIdx := r.Intn(len(validGeneralRiskLevels)) + levelIdx := thresholdIdx + r.Intn(len(validGeneralRiskLevels)-thresholdIdx) + + riskLevel := validGeneralRiskLevels[levelIdx] + threshold := validGeneralRiskLevels[thresholdIdx] + + // Sanity check + if LevelToInt(riskLevel) < LevelToInt(threshold) { + t.Errorf("generator bug: riskLevel=%s (%d) < threshold=%s (%d)", + riskLevel, LevelToInt(riskLevel), threshold, LevelToInt(threshold)) + return false + } + + config := baseConfig() + config.ContentModerationLevelBar = threshold + // Set promptAttackLevelBar to max so it doesn't interfere + config.PromptAttackLevelBar = MaxRisk + + // Generate random details to show they don't matter + numDetails := r.Intn(4) // 0-3 random details + details := make([]Detail, numDetails) + for i := 0; i < numDetails; i++ { + detailType := knownDetailTypes[r.Intn(len(knownDetailTypes))] + var level string + if detailType == SensitiveDataType { + level = validSensitiveLevels[r.Intn(len(validSensitiveLevels))] + } else { + level = validGeneralRiskLevels[r.Intn(len(validGeneralRiskLevels))] + } + details[i] = Detail{ + Type: detailType, + Suggestion: "pass", + Level: level, + } + } + + data := Data{ + RiskLevel: riskLevel, + Detail: details, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + if result != 
RiskBlock { + t.Errorf("expected RiskBlock for RiskLevel=%s, contentModerationLevelBar=%s, got %d", + riskLevel, threshold, result) + return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 5a failed: %v", err) + fmt.Printf("Property 5a counterexample: %v\n", err) + } +} + +func TestProperty5b_TopLevelAttackLevelGateProducesRiskBlock(t *testing.T) { + f := func(seed uint64) bool { + r := rand.New(rand.NewSource(int64(seed))) + + // Pick (attackLevel, threshold) where LevelToInt(attackLevel) >= LevelToInt(threshold) + thresholdIdx := r.Intn(len(validGeneralRiskLevels)) + levelIdx := thresholdIdx + r.Intn(len(validGeneralRiskLevels)-thresholdIdx) + + attackLevel := validGeneralRiskLevels[levelIdx] + threshold := validGeneralRiskLevels[thresholdIdx] + + // Sanity check + if LevelToInt(attackLevel) < LevelToInt(threshold) { + t.Errorf("generator bug: attackLevel=%s (%d) < threshold=%s (%d)", + attackLevel, LevelToInt(attackLevel), threshold, LevelToInt(threshold)) + return false + } + + config := baseConfig() + config.PromptAttackLevelBar = threshold + // Set contentModerationLevelBar to max so it doesn't interfere + config.ContentModerationLevelBar = MaxRisk + + // Generate random details to show they don't matter + numDetails := r.Intn(4) // 0-3 random details + details := make([]Detail, numDetails) + for i := 0; i < numDetails; i++ { + detailType := knownDetailTypes[r.Intn(len(knownDetailTypes))] + var level string + if detailType == SensitiveDataType { + level = validSensitiveLevels[r.Intn(len(validSensitiveLevels))] + } else { + level = validGeneralRiskLevels[r.Intn(len(validGeneralRiskLevels))] + } + details[i] = Detail{ + Type: detailType, + Suggestion: "pass", + Level: level, + } + } + + data := Data{ + AttackLevel: attackLevel, + RiskLevel: "none", // Avoid contentModeration gate interference + Detail: details, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") 
+ if result != RiskBlock { + t.Errorf("expected RiskBlock for AttackLevel=%s, promptAttackLevelBar=%s, got %d", + attackLevel, threshold, result) + return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 5b failed: %v", err) + fmt.Printf("Property 5b counterexample: %v\n", err) + } +} + +// Feature: sensitive-data-mask-threshold, Property 6: Data.Suggestion=block fallback +// **Validates: Requirements 3.5** +// +// For any set of Details that do not individually trigger block, when Data.Suggestion=block, +// evaluateRiskMultiModal SHALL return RiskBlock. +func TestProperty6_DataSuggestionBlockFallbackProducesRiskBlock(t *testing.T) { + f := func(seed uint64) bool { + r := rand.New(rand.NewSource(int64(seed))) + + // Generate 0-4 random non-blocking details. + // Strategy: use Suggestion="pass" or "watch" with levels below their thresholds + // so that no detail individually triggers block. + numDetails := r.Intn(5) // 0-4 details + nonBlockSuggestions := []string{"pass", "watch"} + details := make([]Detail, numDetails) + + for i := 0; i < numDetails; i++ { + detailType := knownDetailTypes[r.Intn(len(knownDetailTypes))] + suggestion := nonBlockSuggestions[r.Intn(len(nonBlockSuggestions))] + + // Use "none" level (0) which is always below any meaningful threshold + // since all thresholds are set to max. 
+ var level string + if detailType == SensitiveDataType { + level = "S0" + } else { + level = "none" + } + + details[i] = Detail{ + Type: detailType, + Suggestion: suggestion, + Level: level, + } + } + + config := baseConfig() + // Set all thresholds to max so no detail exceeds threshold + config.ContentModerationLevelBar = MaxRisk + config.PromptAttackLevelBar = MaxRisk + config.SensitiveDataLevelBar = S4Sensitive + config.MaliciousUrlLevelBar = MaxRisk + config.ModelHallucinationLevelBar = MaxRisk + config.CustomLabelLevelBar = MaxRisk + config.RiskAction = "block" + + data := Data{ + RiskLevel: "none", // Avoid top-level RiskLevel gate + AttackLevel: "", // Avoid top-level AttackLevel gate + Suggestion: "block", // The fallback that should trigger RiskBlock + Detail: details, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + if result != RiskBlock { + t.Errorf("expected RiskBlock for Data.Suggestion=block with %d non-blocking details, got %d", + numDetails, result) + return false + } + return true + } + + cfg := &quick.Config{MaxCount: 200} + if err := quick.Check(f, cfg); err != nil { + t.Errorf("Property 6 failed: %v", err) + fmt.Printf("Property 6 counterexample: %v\n", err) + } +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_test.go b/plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_test.go new file mode 100644 index 000000000..7f2aa2723 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/config/evaluate_risk_test.go @@ -0,0 +1,1109 @@ +// Copyright (c) 2024 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// baseConfig returns a config with all thresholds set to max (most permissive) +// so that tests can focus on specific dimension action behavior. +func baseConfig() AISecurityConfig { + return AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: MaxRisk, + PromptAttackLevelBar: MaxRisk, + SensitiveDataLevelBar: S4Sensitive, + MaliciousUrlLevelBar: MaxRisk, + ModelHallucinationLevelBar: MaxRisk, + CustomLabelLevelBar: MaxRisk, + } +} + +// ============================================================================= +// TC-EVAL: 风险判定核心测试(EvaluateRisk) +// ============================================================================= + +// TestTC_EVAL_001 MultiModalGuard,sensitiveDataAction=mask,Suggestion=mask,无 block => RiskMask +func TestTC_EVAL_001(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S2" // Lower threshold to match detail Level=S2 + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked-text"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// TestTC_EVAL_002 同上但 Suggestion=block => RiskBlock +func TestTC_EVAL_002(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: 
"block", + Type: SensitiveDataType, + Level: "S2", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_003 promptAttackAction=block 且该维度超阈值 => RiskBlock +func TestTC_EVAL_003(t *testing.T) { + config := baseConfig() + config.PromptAttackAction = "block" + config.PromptAttackLevelBar = "high" // threshold = high + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "high", // level >= threshold + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_004 同时存在 sensitiveData(mask) 与 promptAttack(block) 命中 => RiskBlock +func TestTC_EVAL_004(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.PromptAttackAction = "block" + config.PromptAttackLevelBar = "high" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "high", // exceeds threshold + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_005 仅有 mask 候选且无 block => RiskMask +func TestTC_EVAL_005(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S1" // Lower threshold to match detail Level=S1 + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// TestTC_EVAL_006 各维度均不超阈值且无建议 => RiskPass +func 
TestTC_EVAL_006(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", + }, + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "low", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_007 顶层 RiskLevel 超 contentModerationLevelBar => RiskBlock +func TestTC_EVAL_007(t *testing.T) { + config := baseConfig() + config.ContentModerationLevelBar = "high" + + data := Data{ + RiskLevel: "high", // >= threshold + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_008 顶层 AttackLevel 超 promptAttackLevelBar => RiskBlock +func TestTC_EVAL_008(t *testing.T) { + config := baseConfig() + config.PromptAttackLevelBar = "high" + + data := Data{ + RiskLevel: "none", + AttackLevel: "high", // >= threshold + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_009 未知 detail.Type 且 Suggestion=pass/watch => 不触发 block/mask +func TestTC_EVAL_009(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: "unknownType", + Level: "high", + }, + { + Suggestion: "watch", + Type: "anotherUnknown", + Level: "medium", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_010 TextModerationPlus 下配置动作字段 => 仅按 RiskLevelBar 决策 +func TestTC_EVAL_010(t *testing.T) { + config := baseConfig() + config.RiskLevelBar = "high" + config.SensitiveDataAction = "mask" + + // RiskLevel=low < threshold=high => RiskPass + data := Data{ + RiskLevel: "low", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + }, + } + + 
result := EvaluateRisk(TextModerationPlus, data, config, "") + require.Equal(t, RiskPass, result) + + // RiskLevel=high >= threshold=high => RiskBlock + data2 := Data{ + RiskLevel: "high", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + }, + }, + } + + result2 := EvaluateRisk(TextModerationPlus, data2, config, "") + require.Equal(t, RiskBlock, result2) +} + +// TestTC_EVAL_011 contentModerationAction=mask,但顶层 RiskLevel 超阈值 => RiskBlock +func TestTC_EVAL_011(t *testing.T) { + config := baseConfig() + config.ContentModerationAction = "mask" + config.ContentModerationLevelBar = "high" + + data := Data{ + RiskLevel: "high", // >= threshold => 顶层门控触发 + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_012 promptAttackAction=mask,但顶层 AttackLevel 超阈值 => RiskBlock +func TestTC_EVAL_012(t *testing.T) { + config := baseConfig() + config.PromptAttackAction = "mask" + config.PromptAttackLevelBar = "high" + + data := Data{ + RiskLevel: "none", + AttackLevel: "high", // >= threshold => 顶层门控触发 + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_013 顶层未超阈值,Detail(sensitiveData) Suggestion=mask 且 action=mask => RiskMask +func TestTC_EVAL_013(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S1" // Lower threshold to match detail Level=S1 + config.ContentModerationLevelBar = "high" + config.PromptAttackLevelBar = "high" + + data := Data{ + RiskLevel: "low", // < threshold + AttackLevel: "none", // < threshold + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", + Result: []Result{{Ext: Ext{Desensitization: "masked-content"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// TestTC_EVAL_014 未知维度 Detail.Type=maliciousFile 且 
Suggestion=block => RiskBlock +func TestTC_EVAL_014(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "block", + Type: MaliciousFileType, + Level: "high", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_015 Detail 不触发拦截,但 Data.Suggestion=block => RiskBlock +func TestTC_EVAL_015(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Suggestion: "block", // 兜底 + Detail: []Detail{ + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_016 Data.Suggestion=mask 但无 sensitiveData 脱敏明细 => 不返回 RiskMask +func TestTC_EVAL_016(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Suggestion: "mask", // Data 级别的 mask,但无 sensitiveData 明细 + Detail: []Detail{ + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_017 contentModerationAction=mask 且 Detail(contentModeration).Suggestion=mask +// => 不返回 RiskMask(降级为 block 语义) +func TestTC_EVAL_017(t *testing.T) { + config := baseConfig() + config.ContentModerationAction = "mask" + config.ContentModerationLevelBar = "high" + + // contentModeration 维度配置 mask,但 enforceMaskBoundary 会降级为 block + // Detail level=low < threshold=high => 不超阈值 => 不触发 block + // Suggestion=mask 对非 sensitiveData 维度不产生 RiskMask + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: ContentModerationType, + Level: "low", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + // contentModeration 的 mask 被降级为 block,level 未超阈值 => 不触发 block + // Suggestion=mask 但 dimAction 已降级为 block => 不进入 mask 分支 + // 
最终 RiskPass + require.Equal(t, RiskPass, result) +} + +// ============================================================================= +// TC-DESENS: 脱敏提取测试(ExtractDesensitization) +// ============================================================================= + +// TestTC_DESENS_001 sensitiveData + Suggestion=mask + Ext.Desensitization => 返回脱敏文本 +func TestTC_DESENS_001(t *testing.T) { + data := Data{ + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{ + {Ext: Ext{Desensitization: "我的电话是1**********"}}, + }, + }, + }, + } + + result := ExtractDesensitization(data) + require.Equal(t, "我的电话是1**********", result) +} + +// TestTC_DESENS_002 非 sensitiveData 且 Suggestion=mask => 忽略 +func TestTC_DESENS_002(t *testing.T) { + data := Data{ + Detail: []Detail{ + { + Suggestion: "mask", + Type: ContentModerationType, + Level: "high", + Result: []Result{ + {Ext: Ext{Desensitization: "some-content"}}, + }, + }, + }, + } + + result := ExtractDesensitization(data) + require.Equal(t, "", result) +} + +// TestTC_DESENS_003 多条 sensitiveData 明细,首条无脱敏、次条有脱敏 => 返回次条 +func TestTC_DESENS_003(t *testing.T) { + data := Data{ + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{ + {Ext: Ext{Desensitization: ""}}, // 首条无脱敏内容 + }, + }, + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S3", + Result: []Result{ + {Ext: Ext{Desensitization: "脱敏后的内容"}}, + }, + }, + }, + } + + result := ExtractDesensitization(data) + require.Equal(t, "脱敏后的内容", result) +} + +// TestTC_DESENS_004 无任何可用脱敏文本 => 返回空字符串 +func TestTC_DESENS_004(t *testing.T) { + data := Data{ + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{ + {Ext: Ext{Desensitization: ""}}, + }, + }, + { + Suggestion: "pass", + Type: SensitiveDataType, + Level: "S1", + Result: []Result{ + {Ext: Ext{Desensitization: "some-text"}}, + }, + }, + }, + } + + result := 
ExtractDesensitization(data) + require.Equal(t, "", result) +} + +// ============================================================================= +// 补充边界测试 +// ============================================================================= + +// TestTC_EVAL_018 MultiModalGuardForBase64 路径走统一判定流程 +func TestTC_EVAL_018(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S2" // Lower threshold to match detail Level=S2 + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked-text"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuardForBase64, data, config, "") + require.Equal(t, RiskMask, result) + + // block 场景 + data2 := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "block", + Type: ContentModerationType, + Level: "high", + }, + }, + } + result2 := EvaluateRisk(MultiModalGuardForBase64, data2, config, "") + require.Equal(t, RiskBlock, result2) +} + +// TestTC_EVAL_019 空 Detail 列表 + Data.Suggestion=block => RiskBlock +func TestTC_EVAL_019(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Suggestion: "block", + Detail: []Detail{}, // 空 Detail 列表 + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_020 空 Detail 列表 + 无 Data.Suggestion => RiskPass +func TestTC_EVAL_020(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Detail: []Detail{}, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_021 多维度混合:sensitiveData(mask) + contentModeration(pass) + promptAttack(block 超阈值) +// => RiskBlock(promptAttack 超阈值触发拦截) +func TestTC_EVAL_021(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.PromptAttackAction = "block" + 
config.PromptAttackLevelBar = "high" + config.ContentModerationLevelBar = "high" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", + }, + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "high", // 超阈值 + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_022 多维度混合:sensitiveData(mask) + contentModeration(block 未超阈值) + promptAttack(block 未超阈值) +// => RiskMask(无 block 触发,有 mask 候选) +func TestTC_EVAL_022(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S2" // Lower threshold to match detail Level=S2 + config.ContentModerationAction = "block" + config.ContentModerationLevelBar = "high" + config.PromptAttackAction = "block" + config.PromptAttackLevelBar = "high" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", // 未超阈值 + }, + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "low", // 未超阈值 + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// TestTC_EVAL_023 未知维度 Type + Suggestion=pass + 高 level => RiskPass +// (detailExceedsThreshold 对未知 Type 返回 false) +func TestTC_EVAL_023(t *testing.T) { + config := baseConfig() + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: MaliciousFileType, // 未知维度(不在 dimensionActionKey 映射中) + Level: "max", // 即使 level 很高 + }, + { + Suggestion: "pass", + Type: WaterMarkType, // 另一个未知维度 + Level: "max", + }, + }, + } + + result := 
EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_024 sensitiveDataAction=mask 但 Suggestion=pass 且 level 超阈值 => RiskBlock +func TestTC_EVAL_024(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S2" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: SensitiveDataType, + Level: "S3", // 超阈值 + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_025 sensitiveDataAction=mask 但 Suggestion=pass 且 level 未超阈值 => RiskPass +func TestTC_EVAL_025(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S4" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: SensitiveDataType, + Level: "S1", // 未超阈值 + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_026 Data.RiskLevel 为空字符串 => 不触发顶层门控 +func TestTC_EVAL_026(t *testing.T) { + config := baseConfig() + config.ContentModerationLevelBar = "high" + + data := Data{ + RiskLevel: "", // 空字符串 + Detail: []Detail{ + { + Suggestion: "pass", + Type: ContentModerationType, + Level: "low", + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_027 consumer 维度动作集成:consumer sensitiveDataAction=mask + riskAction=block +// => sensitiveData 走 mask,promptAttack 走 block +func TestTC_EVAL_027(t *testing.T) { + config := baseConfig() + config.PromptAttackLevelBar = "high" + config.ConsumerRiskLevel = []map[string]interface{}{ + { + "matcher": Matcher{Exact: "user-a"}, + "riskAction": "block", + "sensitiveDataAction": "mask", + "sensitiveDataLevelBar": "S2", // Lower threshold to match detail Level=S2 + }, + } + + // sensitiveData mask + promptAttack 未超阈值 => 
RiskMask + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "low", // 未超阈值 + }, + }, + } + result := EvaluateRisk(MultiModalGuard, data, config, "user-a") + require.Equal(t, RiskMask, result) + + // promptAttack 超阈值 => RiskBlock(即使有 mask 候选) + data2 := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + { + Suggestion: "pass", + Type: PromptAttackType, + Level: "high", // 超阈值 + }, + }, + } + result2 := EvaluateRisk(MultiModalGuard, data2, config, "user-a") + require.Equal(t, RiskBlock, result2) +} + +// TestTC_EVAL_028 Data.Suggestion=block 兜底 + 有 mask 候选 => RiskBlock +// block 兜底优先于 mask 候选 +func TestTC_EVAL_028(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + + data := Data{ + RiskLevel: "none", + Suggestion: "block", // 兜底 block + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_DESENS_005 Detail.Result 为空数组 => 返回空字符串 +func TestTC_DESENS_005(t *testing.T) { + data := Data{ + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", + Result: []Result{}, // 空数组 + }, + }, + } + + result := ExtractDesensitization(data) + require.Equal(t, "", result) +} + +// TestTC_EVAL_029 未命中 consumer 规则时回退全局维度动作 +func TestTC_EVAL_029(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S1" // Lower threshold to match detail Level=S1 + config.ConsumerRiskLevel = []map[string]interface{}{ + { + "matcher": Matcher{Exact: 
"vip-user"}, + "riskAction": "block", + }, + } + + // "other-user" 不匹配任何规则,回退到 global_dimension(mask) + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "other-user") + require.Equal(t, RiskMask, result) +} + +// ============================================================================= +// TC-EVAL: detailExceedsThreshold 各维度覆盖 +// ============================================================================= + +// TestTC_EVAL_030 MaliciousUrlDataType 超阈值 => RiskBlock +func TestTC_EVAL_030(t *testing.T) { + config := baseConfig() + config.MaliciousUrlLevelBar = "medium" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: MaliciousUrlDataType, + Level: "high", // exceeds "medium" + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_031 ModelHallucinationDataType 超阈值 => RiskBlock +func TestTC_EVAL_031(t *testing.T) { + config := baseConfig() + config.ModelHallucinationLevelBar = "medium" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "none", + Type: ModelHallucinationDataType, + Level: "high", // exceeds "medium" + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_032 CustomLabelType 超阈值 => RiskBlock +func TestTC_EVAL_032(t *testing.T) { + config := baseConfig() + config.CustomLabelLevelBar = "low" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "none", + Type: CustomLabelType, + Level: "medium", // exceeds "low" + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskBlock, result) +} + +// TestTC_EVAL_033 MaliciousUrlDataType 未超阈值 => RiskPass +func 
TestTC_EVAL_033(t *testing.T) { + config := baseConfig() + config.MaliciousUrlLevelBar = "high" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "pass", + Type: MaliciousUrlDataType, + Level: "low", // below "high" + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_034 ModelHallucinationDataType 未超阈值 => RiskPass +func TestTC_EVAL_034(t *testing.T) { + config := baseConfig() + config.ModelHallucinationLevelBar = "high" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "none", + Type: ModelHallucinationDataType, + Level: "low", // below "high" + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_035 CustomLabelType 未超阈值 + 有 mask 候选 => RiskMask +func TestTC_EVAL_035(t *testing.T) { + config := baseConfig() + config.CustomLabelLevelBar = "high" + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S1" // Lower threshold to match detail Level=S1 + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "none", + Type: CustomLabelType, + Level: "low", // below "high" + }, + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// ============================================================================= +// TC-EVAL: 阈值边界测试(Threshold Boundary Tests) +// ============================================================================= + +// TestTC_EVAL_036 低于阈值的 mask 建议 => RiskPass +// Config: sensitiveDataAction=mask, sensitiveDataLevelBar=S3 +// Detail: Type=sensitiveData, Suggestion=mask, Level=S1 (低于 S3) +// Expected: RiskPass(Level 未达阈值,跳过脱敏) +// Validates: Requirements 5.2 +func TestTC_EVAL_036(t *testing.T) { + config := 
baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S3" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", // S1 < S3 => 低于阈值 + Result: []Result{{Ext: Ext{Desensitization: "masked"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} + +// TestTC_EVAL_037 恰好达到阈值的 mask 建议 => RiskMask +// Config: sensitiveDataAction=mask, sensitiveDataLevelBar=S2 +// Detail: Type=sensitiveData, Suggestion=mask, Level=S2 (等于 S2) +// Expected: RiskMask(Level 达到阈值,触发脱敏) +// Validates: Requirements 5.3 +func TestTC_EVAL_037(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S2" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", // S2 >= S2 => 达到阈值 + Result: []Result{{Ext: Ext{Desensitization: "masked-text"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// TestTC_EVAL_038 混合高低阈值明细 => RiskMask +// Config: sensitiveDataAction=mask, sensitiveDataLevelBar=S3 +// Details: 一条 Level=S1(低于阈值),一条 Level=S3(达到阈值) +// Expected: RiskMask(达到阈值的明细触发脱敏) +// Validates: Requirements 5.4 +func TestTC_EVAL_038(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S3" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", // S1 < S3 => 低于阈值,不贡献 mask + Result: []Result{{Ext: Ext{Desensitization: "masked-low"}}}, + }, + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S3", // S3 >= S3 => 达到阈值,贡献 mask + Result: []Result{{Ext: Ext{Desensitization: "masked-high"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskMask, result) +} + +// 
TestTC_EVAL_039 所有明细均低于阈值 => RiskPass +// Config: sensitiveDataAction=mask, sensitiveDataLevelBar=S4 +// Details: 两条 sensitiveData,Level=S1 和 Level=S2(均低于 S4) +// Expected: RiskPass(无明细达到阈值,全部跳过脱敏) +// Validates: Requirements 5.2, 5.4 +func TestTC_EVAL_039(t *testing.T) { + config := baseConfig() + config.SensitiveDataAction = "mask" + config.SensitiveDataLevelBar = "S4" + + data := Data{ + RiskLevel: "none", + Detail: []Detail{ + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S1", // S1 < S4 => 低于阈值 + Result: []Result{{Ext: Ext{Desensitization: "masked-1"}}}, + }, + { + Suggestion: "mask", + Type: SensitiveDataType, + Level: "S2", // S2 < S4 => 低于阈值 + Result: []Result{{Ext: Ext{Desensitization: "masked-2"}}}, + }, + }, + } + + result := EvaluateRisk(MultiModalGuard, data, config, "") + require.Equal(t, RiskPass, result) +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.mod b/plugins/wasm-go/extensions/ai-security-guard/go.mod index 963359750..1e204ef19 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.mod +++ b/plugins/wasm-go/extensions/ai-security-guard/go.mod @@ -20,6 +20,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/resp v0.1.1 // indirect - github.com/tidwall/sjson v1.2.5 // indirect + github.com/tidwall/sjson v1.2.5 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go index 98ef4e5af..25e41be19 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go @@ -64,9 +64,13 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur } contentIndex := 0 imageIndex := 0 + hasMasked := false + maskedContent := []byte(content) 
sessionID, _ := utils.GenerateHexID(20) var singleCall func() var singleCallForImage func() + // prevContentIndex tracks the start of the current chunk for masking replacement + prevContentIndex := 0 callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Info(string(responseBody)) if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { @@ -80,11 +84,47 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur proxywasm.ResumeHttpRequest() return } - if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { - if contentIndex >= len(content) { + riskResult := cfg.EvaluateRisk(config.Action, response.Data, config, consumer) + proxywasm.LogInfof("safecheck_resolved_action=%v", riskResult) + switch riskResult { + case cfg.RiskPass: + if contentIndex >= len(maskedContent) { endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "request pass") + if hasMasked { + // All chunks processed, some had masking — replace the content in request body + newBody, replaceErr := utils.ReplaceJsonFieldTextContent(body, config.RequestContentJsonPath, string(maskedContent)) + if replaceErr != nil { + log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr) + // Fall back to block to prevent leaking sensitive data + denyMessage := cfg.DefaultDenyMessage + if config.DenyMessage != "" { + denyMessage = config.DenyMessage + } + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) + if config.ProtocolOriginal { + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + 
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
+ } else {
+ randomID := utils.GenerateRandomChatID()
+ jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
+ }
+ ctx.DontReadResponseBody()
+ config.IncrementCounter("ai_sec_request_deny", 1)
+ ctx.SetUserAttribute("safecheck_status", "request deny")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
+ return
+ }
+ proxywasm.ReplaceHttpRequestBody(newBody)
+ config.IncrementCounter("ai_sec_request_mask", 1)
+ ctx.SetUserAttribute("safecheck_status", "request mask")
+ } else {
+ ctx.SetUserAttribute("safecheck_status", "request pass")
+ }
 ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
 if len(images) > 0 && config.CheckRequestImage {
 singleCallForImage()
@@ -95,45 +135,110 @@
 singleCall()
 }
 return
+ case cfg.RiskMask:
+ desensitization := cfg.ExtractDesensitization(response.Data)
+ if desensitization == "" {
+ proxywasm.LogInfof("safecheck_action_source=mask_fallback_to_block, reason=empty_desensitization")
+ log.Warnf("desensitization content is empty, falling back to block logic")
+ } else {
+ // Replace only the current chunk portion in maskedContent
+ chunkStart := prevContentIndex
+ chunkEnd := contentIndex
+ maskedContent = append(maskedContent[:chunkStart], append([]byte(desensitization), maskedContent[chunkEnd:]...)...)
+ // Adjust contentIndex for the length difference
+ lengthDiff := len(desensitization) - (chunkEnd - chunkStart)
+ contentIndex += lengthDiff
+ hasMasked = true
+ // Continue checking remaining chunks
+ if contentIndex >= len(maskedContent) {
+ // All chunks done, apply the masked content
+ newBody, replaceErr := utils.ReplaceJsonFieldTextContent(body, config.RequestContentJsonPath, string(maskedContent))
+ if replaceErr != nil {
+ log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
+ // Fall back to block to prevent leaking sensitive data
+ denyMessage := cfg.DefaultDenyMessage
+ if config.DenyMessage != "" {
+ denyMessage = config.DenyMessage
+ }
+ marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
+ if config.ProtocolOriginal {
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
+ } else if gjson.GetBytes(body, "stream").Bool() {
+ randomID := utils.GenerateRandomChatID()
+ jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
+ } else {
+ randomID := utils.GenerateRandomChatID()
+ jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
+ }
+ ctx.DontReadResponseBody()
+ config.IncrementCounter("ai_sec_request_deny", 1)
+ endTime := time.Now().UnixMilli()
+ ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
+ ctx.SetUserAttribute("safecheck_status", "request deny")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
+ return
+ }
+ proxywasm.ReplaceHttpRequestBody(newBody)
+ config.IncrementCounter("ai_sec_request_mask", 1)
+ endTime := time.Now().UnixMilli()
+ 
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
+ ctx.SetUserAttribute("safecheck_status", "request mask")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
+ if len(images) > 0 && config.CheckRequestImage {
+ singleCallForImage()
+ } else {
+ proxywasm.ResumeHttpRequest()
+ }
+ } else {
+ singleCall()
+ }
+ return
+ }
+ // Fall through to block logic when desensitization is empty
+ fallthrough
+ case cfg.RiskBlock:
+ denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
+ if err != nil {
+ log.Errorf("failed to build deny response body: %v", err)
+ proxywasm.ResumeHttpRequest()
+ return
+ }
+ if config.ProtocolOriginal {
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
+ } else if gjson.GetBytes(body, "stream").Bool() {
+ randomID := utils.GenerateRandomChatID()
+ marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
+ jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
+ } else {
+ randomID := utils.GenerateRandomChatID()
+ marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
+ jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
+ proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
+ }
+ ctx.DontReadResponseBody()
+ config.IncrementCounter("ai_sec_request_deny", 1)
+ endTime := time.Now().UnixMilli()
+ ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
+ ctx.SetUserAttribute("safecheck_status", "request deny")
+ if response.Data.Advice != nil {
+ ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
+ ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
+ }
+ 
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) } - denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) - if err != nil { - log.Errorf("failed to build deny response body: %v", err) - proxywasm.ResumeHttpRequest() - return - } - if config.ProtocolOriginal { - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := utils.GenerateRandomChatID() - marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) - jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } - ctx.DontReadResponseBody() - config.IncrementCounter("ai_sec_request_deny", 1) - endTime := time.Now().UnixMilli() - ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { - ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) - ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) - } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) } singleCall = func() { + prevContentIndex = contentIndex var nextContentIndex int - if contentIndex+cfg.LengthLimit >= len(content) { - nextContentIndex = len(content) + if contentIndex+cfg.LengthLimit >= len(maskedContent) { + nextContentIndex = len(maskedContent) } else { nextContentIndex = contentIndex + cfg.LengthLimit } - contentPiece := 
content[contentIndex:nextContentIndex] + contentPiece := string(maskedContent[contentIndex:nextContentIndex]) contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID) diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 67f55923a..10ead8850 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -16,13 +16,16 @@ package main import ( "encoding/json" + "strings" "testing" cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/proxytest" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/test" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) // 测试配置:基础安全配置 @@ -134,6 +137,56 @@ var consumerSpecificConfig = func() json.RawMessage { return data }() +// 测试配置:包含 customLabelLevelBar 和消费者级别覆盖 +var customLabelConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "customLabelLevelBar": "high", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "consumerRiskLevel": []map[string]interface{}{ + { + "name": "exact-user", + "matchType": "exact", + "customLabelLevelBar": "low", + }, + { + "name": "prefix-", + "matchType": "prefix", + "customLabelLevelBar": "medium", + }, + }, + }) + return data +}() + +// 测试配置:脱敏模式配置(riskAction=mask) +var maskConfig = func() 
json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": false, + "action": "MultiModalGuard", + "riskAction": "mask", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + // 测试配置:MCP配置 var mcpConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -156,7 +209,6 @@ var mcpConfig = func() json.RawMessage { return data }() -// 测试配置:MCP配置(启用请求检查) var mcpRequestConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", @@ -240,6 +292,26 @@ var multiModalGuardImageQwenConfig = func() json.RawMessage { return data }() +// 测试配置:ProtocolOriginal MultiModalGuard +var protocolOriginalConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": false, + "action": "MultiModalGuard", + "protocol": "original", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + func TestParseConfig(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试基础配置解析 @@ -414,51 +486,6 @@ func TestOnHttpRequestBody(t *testing.T) { // 空内容应该直接通过 require.Equal(t, types.ActionContinue, action) }) - - // TextModerationPlus(默认 action,含 agent/OpenAI 形态)请求拦截应返回 choices[0].message.content 内的 blockedDetails JSON - t.Run("text moderation plus request deny returns blockedDetails in openai completion shape", func(t *testing.T) { - host, status := test.NewTestHost(basicConfig) - defer host.Reset() - require.Equal(t, 
types.OnPluginStartStatusOK, status) - - host.CallOnHttpRequestHeaders([][2]string{ - {":authority", "example.com"}, - {":path", "/v1/chat/completions"}, - {":method", "POST"}, - }) - - body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` - require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) - - securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-deny", "Data": {"RiskLevel": "high"}}` - host.CallOnHttpCall([][2]string{ - {":status", "200"}, - {"content-type", "application/json"}, - }, []byte(securityResponse)) - - local := host.GetLocalResponse() - require.NotNil(t, local, "expected SendHttpResponse for request deny") - require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), "req-tmp-deny") - - type openAIChatCompletion struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - } - var outer openAIChatCompletion - require.NoError(t, json.Unmarshal(local.Data, &outer)) - require.Len(t, outer.Choices, 1) - - var deny cfg.DenyResponseBody - require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny)) - require.Equal(t, "req-tmp-deny", deny.RequestId) - require.Equal(t, 200, deny.GuardCode) - require.NotEmpty(t, deny.BlockedDetails) - require.Equal(t, cfg.ContentModerationType, deny.BlockedDetails[0].Type) - }) }) } @@ -741,6 +768,189 @@ func TestMCP(t *testing.T) { }) } +func TestGetRiskAction(t *testing.T) { + // 测试全局默认值 + t.Run("default is block", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + require.Equal(t, "block", config.GetRiskAction("any-consumer")) + }) + + // 测试全局配置为 mask,无消费者覆盖 + t.Run("global mask without consumer override", func(t *testing.T) { + config := cfg.AISecurityConfig{RiskAction: "mask"} + require.Equal(t, "mask", config.GetRiskAction("any-consumer")) + }) + + // 测试消费者级别覆盖 riskAction + t.Run("consumer 
overrides riskAction to mask", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "vip-user"}, + "riskAction": "mask", + }, + }, + } + require.Equal(t, "mask", config.GetRiskAction("vip-user")) + require.Equal(t, "block", config.GetRiskAction("normal-user")) + }) + + // 测试消费者匹配但未配置 riskAction,fallback 到全局 + t.Run("consumer matched without riskAction falls back to global", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "some-user"}, + "contentModerationLevelBar": "low", + }, + }, + } + require.Equal(t, "mask", config.GetRiskAction("some-user")) + }) + + // 测试 prefix 匹配 + t.Run("consumer prefix match", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Prefix: "test-"}, + "riskAction": "mask", + }, + }, + } + require.Equal(t, "mask", config.GetRiskAction("test-user-1")) + require.Equal(t, "block", config.GetRiskAction("prod-user")) + }) + + // 测试空 consumer 不匹配任何规则 + t.Run("empty consumer falls back to global", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "vip"}, + "riskAction": "block", + }, + }, + } + require.Equal(t, "mask", config.GetRiskAction("")) + }) +} + +func TestEvaluateRiskWithConsumerRiskAction(t *testing.T) { + // 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof + opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{}) + _, reset := proxytest.NewHostEmulator(opt) + defer reset() + + // 测试全局 block,消费者 mask + t.Run("global block consumer mask", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: 
"max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "vip-user"}, + "riskAction": "mask", + "sensitiveDataLevelBar": "S2", + }, + }, + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Suggestion: "mask", Type: "sensitiveData", Level: "S2", + Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}}, + }, + } + // vip-user 使用 mask 模式,consumer 阈值 S2,Level=S2 >= S2 → RiskMask + require.Equal(t, cfg.RiskMask, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "vip-user")) + // normal-user 使用全局 block 模式,全局阈值 S4,Level=S2 < S4 → RiskPass + require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "normal-user")) + }) + + // 测试全局 mask,消费者 block + t.Run("global mask consumer block", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S2", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "strict-user"}, + "riskAction": "block", + }, + }, + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Suggestion: "mask", Type: "sensitiveData", Level: "S2", + Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}}, + }, + } + // strict-user 使用 block 模式,Level=S2 >= S2 但 Suggestion=mask + dimAction=block → detailTriggersBlock 返回 false(mask suggestion 不触发 block) + // 实际上 detailTriggersBlock: Suggestion != "block", dimAction == "block" → return exceeds + // exceeds = S2 >= S2 = true → RiskBlock + // 所以 strict-user 应该是 RiskBlock + require.Equal(t, cfg.RiskBlock, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "strict-user")) + // other-user 使用全局 mask 模式,Level=S2 >= S2 → RiskMask + require.Equal(t, cfg.RiskMask, cfg.EvaluateRisk(cfg.MultiModalGuard, data, 
config, "other-user")) + }) +} + +func TestParseConsumerRiskActionValidation(t *testing.T) { + // 测试消费者级别 riskAction 无效值 + t.Run("invalid consumer riskAction", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + configJSON := `{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "consumerRiskLevel": [ + {"name": "user1", "matchType": "exact", "riskAction": "invalid"} + ] + }` + err := config.Parse(gjson.Parse(configJSON)) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid riskAction in consumerRiskLevel") + }) + + // 测试消费者级别 riskAction 有效值 + t.Run("valid consumer riskAction", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + configJSON := `{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "consumerRiskLevel": [ + {"name": "user1", "matchType": "exact", "riskAction": "mask"} + ] + }` + err := config.Parse(gjson.Parse(configJSON)) + require.NoError(t, err) + require.Equal(t, "mask", config.GetRiskAction("user1")) + require.Equal(t, "block", config.GetRiskAction("other")) + }) +} + func TestRiskLevelFunctions(t *testing.T) { // 测试风险等级转换函数 t.Run("risk level conversion", func(t *testing.T) { @@ -779,6 +989,1193 @@ func TestUtilityFunctions(t *testing.T) { }) } +func TestCustomLabelConstant(t *testing.T) { + // 验证 CustomLabelType 常量值 + require.Equal(t, "customLabel", cfg.CustomLabelType) +} + +func TestCustomLabelConfigParsing(t *testing.T) { + // 测试 customLabelLevelBar 设置为 high + t.Run("customLabelLevelBar set to high", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + configJSON := `{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + 
"customLabelLevelBar": "high" + }` + err := config.Parse(gjson.Parse(configJSON)) + require.NoError(t, err) + require.Equal(t, "high", config.CustomLabelLevelBar) + }) + + // 测试 customLabelLevelBar 设置为 max + t.Run("customLabelLevelBar set to max", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + configJSON := `{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "customLabelLevelBar": "max" + }` + err := config.Parse(gjson.Parse(configJSON)) + require.NoError(t, err) + require.Equal(t, "max", config.CustomLabelLevelBar) + }) + + // 测试 customLabelLevelBar 缺省时默认为 max + t.Run("customLabelLevelBar defaults to max", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + configJSON := `{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk" + }` + err := config.Parse(gjson.Parse(configJSON)) + require.NoError(t, err) + require.Equal(t, "max", config.CustomLabelLevelBar) + }) + + // 测试 customLabelLevelBar 无效值 + t.Run("customLabelLevelBar invalid value", func(t *testing.T) { + config := cfg.AISecurityConfig{} + config.SetDefaultValues() + configJSON := `{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "customLabelLevelBar": "invalid" + }` + err := config.Parse(gjson.Parse(configJSON)) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid customLabelLevelBar") + }) +} + +func TestGetCustomLabelLevelBar(t *testing.T) { + // 测试消费者精确匹配 + t.Run("consumer exact match", func(t *testing.T) { + config := cfg.AISecurityConfig{ + CustomLabelLevelBar: "max", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "exact-user"}, + "customLabelLevelBar": "low", + }, + }, + } 
+ require.Equal(t, "low", config.GetCustomLabelLevelBar("exact-user")) + }) + + // 测试消费者前缀匹配 + t.Run("consumer prefix match", func(t *testing.T) { + config := cfg.AISecurityConfig{ + CustomLabelLevelBar: "max", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Prefix: "prefix-"}, + "customLabelLevelBar": "medium", + }, + }, + } + require.Equal(t, "medium", config.GetCustomLabelLevelBar("prefix-user")) + }) + + // 测试无匹配回退全局值 + t.Run("no match falls back to global", func(t *testing.T) { + config := cfg.AISecurityConfig{ + CustomLabelLevelBar: "high", + ConsumerRiskLevel: []map[string]interface{}{ + { + "matcher": cfg.Matcher{Exact: "other-user"}, + "customLabelLevelBar": "low", + }, + }, + } + require.Equal(t, "high", config.GetCustomLabelLevelBar("unmatched-user")) + }) +} + +func TestCustomLabelDetailExceedsThreshold(t *testing.T) { + // 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof + opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{}) + _, reset := proxytest.NewHostEmulator(opt) + defer reset() + + // 测试 customLabel Level=high, threshold=high → 拦截 (true) + t.Run("level high threshold high blocks", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "high", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Type: cfg.CustomLabelType, Level: "high"}, + }, + } + require.Equal(t, cfg.RiskBlock, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "")) + }) + + // 测试 customLabel Level=none, threshold=high → 不拦截 (false) + t.Run("level none threshold high passes", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: 
"max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "high", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Type: cfg.CustomLabelType, Level: "none"}, + }, + } + require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "")) + }) + + // 测试 customLabel Level=high, threshold=max → 不拦截 (false) + t.Run("level high threshold max passes", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Type: cfg.CustomLabelType, Level: "high"}, + }, + } + require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "")) + }) +} + +func TestCustomLabelConfigIntegration(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + // 测试 customLabelConfig 配置解析和消费者覆盖 + t.Run("customLabel config with consumer override", func(t *testing.T) { + host, status := test.NewTestHost(customLabelConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "high", securityConfig.CustomLabelLevelBar) + require.Equal(t, "low", securityConfig.GetCustomLabelLevelBar("exact-user")) + require.Equal(t, "medium", securityConfig.GetCustomLabelLevelBar("prefix-user")) + require.Equal(t, "high", securityConfig.GetCustomLabelLevelBar("unknown-user")) + }) + }) +} + +func TestRequestMasking(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试请求阶段脱敏成功:riskAction=mask,API 返回 mask 建议,请求体被替换为脱敏内容 + t.Run("request masking success", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, 
types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议,包含脱敏内容 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "low", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S3", + "Result": [{ + "Label": "phone_number", + "Confidence": 99.0, + "Ext": { + "Desensitization": "我的电话是1**********", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // 验证请求体被替换为脱敏内容 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "我的电话是1**********", content) + + host.CompleteHttp() + }) + + // 测试脱敏内容为空时回退到拦截 + t.Run("empty desensitization falls back to block", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议,但 Desensitization 为空 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "low", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S3", + "Result": [{ + "Label": "phone_number", + 
"Confidence": 99.0, + "Ext": { + "Desensitization": "", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // Desensitization 为空时应回退到拦截,SendHttpResponse 被调用 + // 验证请求体未被修改(原始内容保持不变) + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "我的电话是13800138000", content) + }) + + // 测试 riskAction=block 时 mask 建议按现有逻辑处理(向后兼容) + t.Run("riskAction block keeps existing behavior for mask suggestion", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议,但全局 riskAction 默认为 block + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "low", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S2", + "Result": [{ + "Label": "phone_number", + "Confidence": 99.0, + "Ext": { + "Desensitization": "我的电话是1**********", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // riskAction=block 时,mask 建议按风险等级判断,low 级别应放行 + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + + // 请求体不应被脱敏修改 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + 
require.Equal(t, "我的电话是13800138000", content) + + host.CompleteHttp() + }) + + // 测试 riskAction=mask 时 block 建议优先拦截 + t.Run("block suggestion takes priority over mask", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "违规内容"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 block 建议 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "high", + "Detail": [{ + "Suggestion": "block", + "Type": "contentModeration", + "Level": "high" + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // block 建议应拦截请求,请求体不应被修改 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "违规内容", content) + }) + }) +} + +// 测试配置:MCP + 脱敏模式 +var mcpMaskConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "apiType": "mcp", + "riskAction": "mask", + "requestContentJsonPath": "params.arguments.input", + "responseContentJsonPath": "content", + "responseStreamContentJsonPath": "content", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + +func TestIsRiskLevelAcceptable(t *testing.T) { + // 需要 proxy-wasm host 
环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof + opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{}) + _, reset := proxytest.NewHostEmulator(opt) + defer reset() + + // 用例 1: riskAction=mask, Suggestion=mask → 应返回 true(mask 不应被视为不可接受) + t.Run("mask action with mask suggestion is acceptable", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Suggestion: "mask", Type: "sensitiveData", Level: "S2", + Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}}, + }, + } + require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) + }) + + // 用例 2: riskAction=mask, Suggestion=block → 应返回 false + t.Run("mask action with block suggestion is not acceptable", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Suggestion: "block", Type: "contentModeration", Level: "high"}, + }, + } + require.False(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) + }) + + // 用例 3: riskAction=mask, 无风险 → 应返回 true + t.Run("mask action with no risk is acceptable", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + } + data := cfg.Data{RiskLevel: "low"} + require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, 
config, "")) + }) + + // 用例 4: riskAction=block, Suggestion=mask, level 未超阈值 → 应返回 true(向后兼容) + t.Run("block action with mask suggestion below threshold is acceptable", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Suggestion: "mask", Type: "sensitiveData", Level: "S2"}, + }, + } + require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) + }) + + // 用例 5: riskAction=block, Suggestion=mask, level 超阈值 → 应返回 false(向后兼容) + t.Run("block action with mask suggestion exceeding threshold is not acceptable", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "block", + ContentModerationLevelBar: "max", + PromptAttackLevelBar: "max", + SensitiveDataLevelBar: "S2", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + } + data := cfg.Data{ + RiskLevel: "none", + Detail: []cfg.Detail{ + {Suggestion: "mask", Type: "sensitiveData", Level: "S2"}, + }, + } + require.False(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) + }) + + // 用例 6: TextModerationPlus, riskAction=mask → 不受影响 + t.Run("TextModerationPlus not affected by riskAction mask", func(t *testing.T) { + config := cfg.AISecurityConfig{ + RiskAction: "mask", + RiskLevelBar: "high", + } + data := cfg.Data{RiskLevel: "low"} + require.True(t, cfg.IsRiskLevelAcceptable(cfg.TextModerationPlus, data, config, "")) + }) +} + +func TestMcpMaskNotBlock(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 用例 7: MCP 请求, riskAction=mask, API 返回 Suggestion=mask → 应放行 + t.Run("mcp request with mask suggestion should pass not block", func(t *testing.T) { + host, status := test.NewTestHost(mcpMaskConfig) + defer host.Reset() + 
require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/mcp"}, + {":method", "POST"}, + }) + + body := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test","arguments":{"input":"我的电话是13800138000"}}}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "low", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S2", + "Result": [{ + "Label": "phone_number", + "Confidence": 99.0, + "Ext": { + "Desensitization": "我的电话是1**********", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // 应放行而非拦截 + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + // 用例 8: MCP 响应, riskAction=mask, API 返回 Suggestion=mask → 应放行 + t.Run("mcp response with mask suggestion should pass not block", func(t *testing.T) { + host, status := test.NewTestHost(mcpMaskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/mcp"}, + {":method", "POST"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + body := `{"content": "我的电话是13800138000"}` + action := host.CallOnHttpResponseBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "low", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S2", + "Result": [{ + "Label": 
"phone_number", + "Confidence": 99.0, + "Ext": { + "Desensitization": "我的电话是1**********", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // 应放行而非拦截 + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + // 用例 9: MCP 请求, riskAction=mask, API 返回 Suggestion=block → 应拦截 + t.Run("mcp request with block suggestion should deny", func(t *testing.T) { + host, status := test.NewTestHost(mcpMaskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/mcp"}, + {":method", "POST"}, + }) + + body := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test","arguments":{"input":"违规内容"}}}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 block 建议 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-123", + "Data": { + "RiskLevel": "high", + "Detail": [{ + "Suggestion": "block", + "Type": "contentModeration", + "Level": "high" + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // block 建议应拦截(SendHttpResponse 被调用,请求不会继续) + // MCP handler 调用 SendHttpResponse 后不会 resume + }) + }) +} + +// ============================================================================= +// TC-PARSE: 配置解析与校验集成测试 +// ============================================================================= + +// 测试配置:MultiModalGuard + 全局维度动作全为合法值 +var dimensionActionValidConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": 
"test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "contentModerationAction": "block", + "promptAttackAction": "block", + "sensitiveDataAction": "mask", + "maliciousUrlAction": "block", + "modelHallucinationAction": "block", + "customLabelAction": "block", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + }) + return data +}() + +// 测试配置:MultiModalGuard + 全局维度动作出现非法值 +var dimensionActionInvalidConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "contentModerationAction": "allow", // 非法值 + }) + return data +}() + +// 测试配置:MultiModalGuard + consumerRiskLevel 内维度动作非法 +var consumerDimensionActionInvalidConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "action": "MultiModalGuard", + "consumerRiskLevel": []map[string]interface{}{ + { + "name": "user-a", + "matchType": "exact", + "sensitiveDataAction": "deny", // 非法值 + }, + }, + }) + return data +}() + +// 测试配置:TextModerationPlus + 配置了维度动作 +var textModPlusDimensionActionConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "action": "TextModerationPlus", + "sensitiveDataAction": "mask", + "contentModerationAction": "block", + }) + return data +}() + +// 测试配置:未配置任何动作字段 +var noActionFieldConfig = func() json.RawMessage { + data, _ := 
json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + }) + return data +}() + +// TestTC_PARSE_001 MultiModalGuard + 全局维度动作全为合法值 => 启动成功 +func TestTC_PARSE_001(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(dimensionActionValidConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "block", securityConfig.ContentModerationAction) + require.Equal(t, "block", securityConfig.PromptAttackAction) + require.Equal(t, "mask", securityConfig.SensitiveDataAction) + require.Equal(t, "block", securityConfig.MaliciousUrlAction) + require.Equal(t, "block", securityConfig.ModelHallucinationAction) + require.Equal(t, "block", securityConfig.CustomLabelAction) + }) +} + +// TestTC_PARSE_002 MultiModalGuard + 全局维度动作出现非法值 => 启动失败 +func TestTC_PARSE_002(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(dimensionActionInvalidConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) +} + +// TestTC_PARSE_003 MultiModalGuard + consumerRiskLevel 内维度动作非法 => 启动失败 +func TestTC_PARSE_003(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(consumerDimensionActionInvalidConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) +} + +// TestTC_PARSE_004 TextModerationPlus + 配置了维度动作 => 启动成功(字段忽略) +func TestTC_PARSE_004(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(textModPlusDimensionActionConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, 
err) + securityConfig := config.(*cfg.AISecurityConfig) + // 字段被解析但在运行时被忽略(不影响启动) + require.Equal(t, "mask", securityConfig.SensitiveDataAction) + require.Equal(t, "block", securityConfig.ContentModerationAction) + }) +} + +// TestTC_PARSE_005 未配置任何动作字段 => 默认 riskAction=block +func TestTC_PARSE_005(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(noActionFieldConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "block", securityConfig.RiskAction) + require.Equal(t, "", securityConfig.ContentModerationAction) + require.Equal(t, "", securityConfig.PromptAttackAction) + require.Equal(t, "", securityConfig.SensitiveDataAction) + require.Equal(t, "", securityConfig.MaliciousUrlAction) + require.Equal(t, "", securityConfig.ModelHallucinationAction) + require.Equal(t, "", securityConfig.CustomLabelAction) + }) +} + +// TestTC_PARSE_006 TextModerationPlus + 非法维度动作值 => 启动成功(需求 8.2) +func TestTC_PARSE_006(t *testing.T) { + // 非 MultiModalGuard 下配置了非法维度动作值,不应报错,应启动成功 + invalidDimActionTextModConfig := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "action": "TextModerationPlus", + "contentModerationAction": "allow", // 非法值,但非 MultiModalGuard 下应忽略 + "sensitiveDataAction": "deny", // 非法值 + }) + return data + }() + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(invalidDimActionTextModConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + }) +} + +// TestTC_PARSE_007 TextModerationPlus + consumerRiskLevel 内非法维度动作值 => 启动成功(需求 8.2) +func TestTC_PARSE_007(t *testing.T) { + invalidConsumerDimActionTextModConfig := 
func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "action": "TextModerationPlus", + "consumerRiskLevel": []map[string]interface{}{ + { + "name": "user-a", + "matchType": "exact", + "sensitiveDataAction": "invalid-value", // 非法值,但非 MultiModalGuard 下应忽略 + }, + }, + }) + return data + }() + test.RunGoTest(t, func(t *testing.T) { + host, status := test.NewTestHost(invalidConsumerDimActionTextModConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + }) +} + +// ============================================================================= +// TC-REG: 回归测试 +// ============================================================================= + +// 测试配置:历史仅 riskAction=block 的 MultiModalGuard 配置(无维度动作字段) +var legacyRiskActionBlockConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "riskAction": "block", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + +// 测试配置:历史仅 riskAction=mask 的 MultiModalGuard 配置(无维度动作字段) +var legacyRiskActionMaskConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "riskAction": "mask", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 
2000, + }) + return data +}() + +// TestTC_REG_004 历史仅 riskAction 配置的场景不回归 +// 验证:仅配置 riskAction(不配置任何维度动作字段)时,新代码行为与历史一致 +func TestTC_REG_004(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 子用例 1: riskAction=block,请求安全检查通过(低风险)=> 放行 + t.Run("legacy block config pass on low risk", func(t *testing.T) { + host, status := test.NewTestHost(legacyRiskActionBlockConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "Hello"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回低风险,无 Detail 触发 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-reg-001", + "Data": { + "RiskLevel": "none", + "Detail": [] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + // 子用例 2: riskAction=block,顶层 RiskLevel 超阈值 => 拦截 + t.Run("legacy block config blocks on high risk level", func(t *testing.T) { + host, status := test.NewTestHost(legacyRiskActionBlockConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "违规内容"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回高风险,顶层 RiskLevel 超阈值 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-reg-002", + "Data": { + "RiskLevel": "high", + "Detail": [{ + "Suggestion": "block", + "Type": 
"contentModeration", + "Level": "high" + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // 拦截:请求体不应被修改 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "违规内容", content) + }) + + // 子用例 3: riskAction=block,Detail 有 mask 建议但 level 未超阈值 => 放行(block 模式忽略 mask) + t.Run("legacy block config ignores mask suggestion below threshold", func(t *testing.T) { + host, status := test.NewTestHost(legacyRiskActionBlockConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议,但 riskAction=block 且 level 未超阈值 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-reg-003", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S2", + "Result": [{ + "Label": "phone_number", + "Confidence": 99.0, + "Ext": { + "Desensitization": "我的电话是1**********", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // riskAction=block 时,mask 建议不触发脱敏,level 未超阈值应放行 + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + + // 请求体不应被脱敏修改 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "我的电话是13800138000", content) + host.CompleteHttp() + }) + 
+ // 子用例 4: riskAction=mask,Detail 有 mask 建议 => 脱敏替换 + t.Run("legacy mask config applies desensitization", func(t *testing.T) { + host, status := test.NewTestHost(legacyRiskActionMaskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // API 返回 mask 建议,riskAction=mask 应触发脱敏 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-reg-004", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", + "Type": "sensitiveData", + "Level": "S3", + "Result": [{ + "Label": "phone_number", + "Confidence": 99.0, + "Ext": { + "Desensitization": "我的电话是1**********", + "SensitiveData": ["13800138000"] + } + }] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // riskAction=mask 时,mask 建议应触发脱敏替换 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "我的电话是1**********", content) + host.CompleteHttp() + }) + + // 子用例 5: riskAction=mask,Detail 有 block 建议 => 仍然拦截(block 优先) + t.Run("legacy mask config still blocks on block suggestion", func(t *testing.T) { + host, status := test.NewTestHost(legacyRiskActionMaskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "违规内容"}]}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, 
types.ActionPause, action) + + // API 返回 block 建议,即使 riskAction=mask 也应拦截 + securityResponse := `{ + "Code": 200, + "Message": "Success", + "RequestId": "req-reg-005", + "Data": { + "RiskLevel": "high", + "Detail": [{ + "Suggestion": "block", + "Type": "contentModeration", + "Level": "high" + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // block 建议应拦截,请求体不应被修改 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "违规内容", content) + }) + }) +} func TestMultiModalGuardTextGenerationDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // MultiModalGuard text_generation request deny → exercises multi_modal_guard/text/openai.go BuildDenyResponseBody path @@ -805,7 +2202,6 @@ func TestMultiModalGuardTextGenerationDeny(t *testing.T) { local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for request deny") require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), "req-mmg-text-deny") }) // MultiModalGuard text_generation response deny → exercises common/text/openai.go HandleTextGenerationResponseBody BuildDenyResponseBody path @@ -838,7 +2234,6 @@ func TestMultiModalGuardTextGenerationDeny(t *testing.T) { local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for response deny") require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), "req-mmg-resp-deny") }) // MultiModalGuard text_generation request pass @@ -895,7 +2290,6 @@ func TestMultiModalGuardImageGenerationDeny(t *testing.T) { local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for OpenAI image request deny") require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), 
"req-img-openai-deny") }) // OpenAI image generation request pass @@ -948,7 +2342,6 @@ func TestMultiModalGuardImageGenerationDeny(t *testing.T) { local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for Qwen image request deny") require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), "req-img-qwen-deny") }) // Qwen image generation request pass @@ -1005,7 +2398,6 @@ func TestMCPRequestDeny(t *testing.T) { local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for MCP request deny") require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), "req-mcp-deny") }) // MCP request pass @@ -1085,7 +2477,6 @@ func TestTextModerationPlusResponseDeny(t *testing.T) { local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for response deny") require.Contains(t, string(local.Data), "blockedDetails") - require.Contains(t, string(local.Data), "req-tmp-resp-deny") // Verify OpenAI completion shape wrapper type openAIChatCompletion struct { @@ -1101,8 +2492,7 @@ func TestTextModerationPlusResponseDeny(t *testing.T) { var deny cfg.DenyResponseBody require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny)) - require.Equal(t, "req-tmp-resp-deny", deny.RequestId) - require.Equal(t, 200, deny.GuardCode) + require.Equal(t, 200, deny.Code) require.NotEmpty(t, deny.BlockedDetails) }) }) @@ -1111,16 +2501,18 @@ func TestTextModerationPlusResponseDeny(t *testing.T) { func TestBuildDenyResponseBody(t *testing.T) { makeConfig := func(contentBar, promptBar string) cfg.AISecurityConfig { return cfg.AISecurityConfig{ - ContentModerationLevelBar: contentBar, - PromptAttackLevelBar: promptBar, - SensitiveDataLevelBar: "S4", - MaliciousUrlLevelBar: "max", + ContentModerationLevelBar: contentBar, + PromptAttackLevelBar: promptBar, + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", 
ModelHallucinationLevelBar: "max", - Action: cfg.MultiModalGuard, + CustomLabelLevelBar: "max", + RiskAction: "block", + Action: cfg.MultiModalGuard, } } - t.Run("guardCode equals response.Code", func(t *testing.T) { + t.Run("code equals response.Code", func(t *testing.T) { resp := cfg.Response{ Code: 200, RequestId: "req-123", @@ -1131,8 +2523,7 @@ func TestBuildDenyResponseBody(t *testing.T) { var result cfg.DenyResponseBody require.NoError(t, json.Unmarshal(body, &result)) - require.Equal(t, 200, result.GuardCode) - require.Equal(t, "req-123", result.RequestId) + require.Equal(t, 200, result.Code) }) t.Run("blockedDetails from Data.Detail", func(t *testing.T) { @@ -1142,7 +2533,7 @@ func TestBuildDenyResponseBody(t *testing.T) { Data: cfg.Data{ Detail: []cfg.Detail{ {Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"}, - {Type: cfg.PromptAttackType, Level: "low", Suggestion: "block"}, + {Type: cfg.PromptAttackType, Level: "low", Suggestion: "none"}, }, }, } @@ -1158,6 +2549,50 @@ func TestBuildDenyResponseBody(t *testing.T) { require.Equal(t, "high", result.BlockedDetails[0].Level) }) + t.Run("blockedDetails includes explicit block suggestion below threshold", func(t *testing.T) { + resp := cfg.Response{ + Code: 200, + RequestId: "req-suggestion-block", + Data: cfg.Data{ + Detail: []cfg.Detail{ + {Type: cfg.SensitiveDataType, Level: "S3", Suggestion: "block"}, + }, + }, + } + config := makeConfig("high", "high") + config.SensitiveDataLevelBar = "S4" + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + + var result cfg.DenyResponseBody + require.NoError(t, json.Unmarshal(body, &result)) + require.Len(t, result.BlockedDetails, 1) + require.Equal(t, cfg.SensitiveDataType, result.BlockedDetails[0].Type) + require.Equal(t, "S3", result.BlockedDetails[0].Level) + }) + + t.Run("blockedDetails includes customLabel when threshold exceeded", func(t *testing.T) { + resp := cfg.Response{ + Code: 200, + RequestId: 
"req-custom-label", + Data: cfg.Data{ + Detail: []cfg.Detail{ + {Type: cfg.CustomLabelType, Level: "high", Suggestion: "none"}, + }, + }, + } + config := makeConfig("high", "high") + config.CustomLabelLevelBar = "high" + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + + var result cfg.DenyResponseBody + require.NoError(t, json.Unmarshal(body, &result)) + require.Len(t, result.BlockedDetails, 1) + require.Equal(t, cfg.CustomLabelType, result.BlockedDetails[0].Type) + require.Equal(t, "high", result.BlockedDetails[0].Level) + }) + t.Run("blockedDetails fallback from RiskLevel when Detail is empty", func(t *testing.T) { resp := cfg.Response{ Code: 200, @@ -1176,7 +2611,6 @@ func TestBuildDenyResponseBody(t *testing.T) { require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from RiskLevel") require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type) require.Equal(t, "high", result.BlockedDetails[0].Level) - require.Equal(t, "block", result.BlockedDetails[0].Suggestion) }) t.Run("blockedDetails fallback from AttackLevel when Detail is empty", func(t *testing.T) { @@ -1197,7 +2631,6 @@ func TestBuildDenyResponseBody(t *testing.T) { require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from AttackLevel") require.Equal(t, cfg.PromptAttackType, result.BlockedDetails[0].Type) require.Equal(t, "high", result.BlockedDetails[0].Level) - require.Equal(t, "block", result.BlockedDetails[0].Suggestion) }) t.Run("blockedDetails empty when risk levels below threshold", func(t *testing.T) { @@ -1219,3 +2652,436 @@ func TestBuildDenyResponseBody(t *testing.T) { require.Empty(t, result.BlockedDetails) }) } + +func TestBuildDenyResponseBody_WithDenyMessage(t *testing.T) { + config := cfg.AISecurityConfig{ + ContentModerationLevelBar: "high", + PromptAttackLevelBar: "high", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + 
RiskAction: "block", + Action: cfg.MultiModalGuard, + DenyMessage: "很抱歉,我无法回答您的问题", + } + resp := cfg.Response{ + Code: 200, + Data: cfg.Data{ + Detail: []cfg.Detail{ + {Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"}, + }, + }, + } + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + + var result cfg.DenyResponseBody + require.NoError(t, json.Unmarshal(body, &result)) + require.Equal(t, "很抱歉,我无法回答您的问题", result.DenyMessage) +} + +func TestBuildDenyResponseBody_WithoutDenyMessage(t *testing.T) { + config := cfg.AISecurityConfig{ + ContentModerationLevelBar: "high", + PromptAttackLevelBar: "high", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + RiskAction: "block", + Action: cfg.MultiModalGuard, + } + resp := cfg.Response{ + Code: 200, + Data: cfg.Data{ + Detail: []cfg.Detail{ + {Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"}, + }, + }, + } + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + require.NotContains(t, string(body), "denyMessage") +} + +func TestBuildDenyResponseBody_BlockedDetailsOnlyTypeAndLevel(t *testing.T) { + config := cfg.AISecurityConfig{ + ContentModerationLevelBar: "high", + PromptAttackLevelBar: "high", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + RiskAction: "block", + Action: cfg.MultiModalGuard, + } + resp := cfg.Response{ + Code: 200, + Data: cfg.Data{ + Detail: []cfg.Detail{ + {Type: cfg.ContentModerationType, Level: "high", Suggestion: "block", Result: []cfg.Result{{Label: "violence"}}}, + {Type: cfg.PromptAttackType, Level: "high", Suggestion: "block", Result: []cfg.Result{{Label: "injection"}}}, + }, + }, + } + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + + var raw map[string]interface{} + require.NoError(t, 
json.Unmarshal(body, &raw)) + details := raw["blockedDetails"].([]interface{}) + require.Len(t, details, 2) + for _, entry := range details { + m := entry.(map[string]interface{}) + require.Len(t, m, 2, "each blockedDetail entry should have exactly 2 keys (type and level)") + require.Contains(t, m, "type") + require.Contains(t, m, "level") + } +} + +func TestBuildDenyResponseBody_CodeField(t *testing.T) { + config := cfg.AISecurityConfig{ + ContentModerationLevelBar: "high", + PromptAttackLevelBar: "high", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + RiskAction: "block", + Action: cfg.MultiModalGuard, + } + resp := cfg.Response{ + Code: 200, + Data: cfg.Data{}, + } + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + + var result cfg.DenyResponseBody + require.NoError(t, json.Unmarshal(body, &result)) + require.Equal(t, 200, result.Code) +} + +func TestBuildDenyResponseBody_NoRequestId(t *testing.T) { + config := cfg.AISecurityConfig{ + ContentModerationLevelBar: "high", + PromptAttackLevelBar: "high", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + RiskAction: "block", + Action: cfg.MultiModalGuard, + } + resp := cfg.Response{ + Code: 200, + RequestId: "req-should-not-appear", + Data: cfg.Data{}, + } + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + require.NotContains(t, string(body), "requestId") +} + +func TestBuildDenyResponseBody_FallbackSynthesis(t *testing.T) { + config := cfg.AISecurityConfig{ + ContentModerationLevelBar: "high", + PromptAttackLevelBar: "high", + SensitiveDataLevelBar: "S4", + MaliciousUrlLevelBar: "max", + ModelHallucinationLevelBar: "max", + CustomLabelLevelBar: "max", + RiskAction: "block", + Action: cfg.MultiModalGuard, + } + resp := cfg.Response{ + Code: 200, + Data: cfg.Data{ + RiskLevel: "high", + 
// No Detail entries — triggers fallback synthesis + }, + } + body, err := cfg.BuildDenyResponseBody(resp, config, "") + require.NoError(t, err) + + var raw map[string]interface{} + require.NoError(t, json.Unmarshal(body, &raw)) + details := raw["blockedDetails"].([]interface{}) + require.NotEmpty(t, details, "expected fallback synthesized entries") + for _, entry := range details { + m := entry.(map[string]interface{}) + require.Len(t, m, 2, "fallback blockedDetail entry should have exactly 2 keys (type and level)") + require.Contains(t, m, "type") + require.Contains(t, m, "level") + } +} + +// ============================================================================= +// TC-COVER: 覆盖率补充测试 +// ============================================================================= + +// TestMultiModalGuardStreamDeny 覆盖 openai.go RiskBlock 分支中 stream 响应格式路径 +func TestMultiModalGuardStreamDeny(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("stream request deny returns SSE format", func(t *testing.T) { + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + // 请求体中包含 stream=true,触发 SSE 响应格式 + body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for stream deny") + require.Contains(t, string(local.Data), "blockedDetails") + // 验证 SSE content-type + foundSSE := false + for _, h := 
range local.Headers { + if h[0] == "content-type" { + require.Equal(t, "text/event-stream;charset=UTF-8", h[1]) + foundSSE = true + } + } + require.True(t, foundSSE, "expected SSE content-type header") + }) + }) +} + +// TestMultiModalGuardProtocolOriginalDeny 覆盖 openai.go RiskBlock 分支中 ProtocolOriginal 路径 +func TestMultiModalGuardProtocolOriginalDeny(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("protocol original deny returns raw blockedDetails JSON", func(t *testing.T) { + host, status := test.NewTestHost(protocolOriginalConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-proto-orig", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for protocol original deny") + require.Contains(t, string(local.Data), "blockedDetails") + // ProtocolOriginal 直接返回 JSON,不包装 OpenAI 格式 + for _, h := range local.Headers { + if h[0] == "content-type" { + require.Equal(t, "application/json", h[1]) + } + } + // 响应体是原始 blockedDetails JSON,不含 OpenAI 包装 + require.False(t, gjson.GetBytes(local.Data, "choices").Exists(), "should not wrap in OpenAI format") + }) + }) +} + +// TestMultiModalGuardDenyWithAdvice 覆盖 openai.go RiskBlock 分支中 Advice != nil 路径 +func TestMultiModalGuardDenyWithAdvice(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("deny with advice sets riskLabel and riskWords attributes", func(t *testing.T) { + host, status := 
test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "trigger deny with advice"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + // 包含 Advice 和 Result 的安全服务响应 + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-advice-deny", + "Data": { + "RiskLevel": "high", + "Result": [{"Label": "porn", "RiskWords": "bad-word"}], + "Advice": [{"Answer": "blocked", "HitLabel": "porn"}], + "Detail": [{"Suggestion": "block", "Type": "contentModeration", "Level": "high"}] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse") + require.Contains(t, string(local.Data), "blockedDetails") + }) + }) +} + +// TestMultiChunkMasking 覆盖 openai.go 中 RiskPass + hasMasked 路径 +// 场景:内容超过 LengthLimit(1800),第一 chunk 触发 RiskMask 脱敏替换,第二 chunk RiskPass +func TestMultiChunkMasking(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("multi chunk masking with pass on second chunk", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + // 生成超过 LengthLimit (1800) 的内容 + longContent := strings.Repeat("a", 2000) + body := `{"messages": [{"role": "user", "content": "` + longContent + `"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + // 第一个 chunk (1800 chars):返回 mask 建议及脱敏内容 + maskedChunk := strings.Repeat("b", 1800) + 
securityResponse1 := `{ + "Code": 200, "Message": "Success", "RequestId": "req-chunk-1", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": [{"Label": "phone", "Confidence": 99.0, + "Ext": {"Desensitization": "` + maskedChunk + `"}}] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse1)) + + // 第二个 chunk (200 chars):返回 pass(无风险) + securityResponse2 := `{"Code": 200, "Message": "Success", "RequestId": "req-chunk-2", "Data": {"RiskLevel": "none", "Detail": []}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse2)) + + // 验证请求体被替换为脱敏内容 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + // 期望 = 脱敏后的第一 chunk (1800 'b') + 原始第二 chunk (200 'a') + expectedContent := maskedChunk + strings.Repeat("a", 200) + require.Equal(t, expectedContent, content) + + host.CompleteHttp() + }) + + // 覆盖 RiskMask 成功完成路径(单 chunk 内容刚好 <= LengthLimit,RiskMask 后立即完成) + // 该路径在 RiskMask 分支中 contentIndex >= len(maskedContent) 的子路径 + t.Run("single chunk mask completes in RiskMask branch", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + body := `{"messages": [{"role": "user", "content": "我的银行卡号是6222021234567890"}]}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-single-mask", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": 
[{"Label": "bank_card", "Confidence": 99.0, + "Ext": {"Desensitization": "我的银行卡号是6222************"}}] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() + require.Equal(t, "我的银行卡号是6222************", content) + + host.CompleteHttp() + }) + }) +} + +// TestMultiModalGuardMaskStreamDeny 覆盖 openai.go RiskMask 空脱敏 fallthrough 到 RiskBlock 的 stream 路径 +func TestMultiModalGuardMaskStreamDeny(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("mask with empty desensitization falls through to block stream format", func(t *testing.T) { + host, status := test.NewTestHost(maskConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + // stream=true 使 block 走 SSE 格式 + body := `{"messages": [{"role": "user", "content": "敏感内容"}], "stream": true}` + require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) + + // 返回 mask 建议但脱敏内容为空 → fallthrough 到 RiskBlock + securityResponse := `{ + "Code": 200, "Message": "Success", "RequestId": "req-mask-stream-deny", + "Data": { + "RiskLevel": "none", + "Detail": [{ + "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", + "Result": [{"Label": "phone", "Confidence": 99.0, + "Ext": {"Desensitization": ""}}] + }] + } + }` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse after mask fallthrough to block") + // 验证是 SSE 格式 + foundSSE := false + for _, h := range local.Headers { + if h[0] == "content-type" { + require.Equal(t, 
"text/event-stream;charset=UTF-8", h[1]) + foundSSE = true + } + } + require.True(t, foundSSE, "expected SSE content-type for stream deny") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go b/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go index bc154efb7..f8edd0d7d 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go +++ b/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go @@ -4,11 +4,14 @@ import ( "bytes" "crypto/rand" "encoding/hex" + "fmt" mrand "math/rand" "strings" + "unicode/utf8" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) func GenerateHexID(length int) (string, error) { @@ -41,3 +44,109 @@ func ExtractMessageFromStreamingBody(data []byte, jsonPath string) string { func GetConsumer(ctx wrapper.HttpContext) string { return ctx.GetStringContext("consumer", "") } + +func ReplaceJsonFieldContent(body []byte, jsonPath string, newContent string) ([]byte, error) { + return sjson.SetBytes(body, resolveJsonPath(body, jsonPath), newContent) +} + +// ReplaceJsonFieldTextContent replaces text content at jsonPath, handling both +// string and array (multimodal) content formats. When the field is an array +// (e.g. OpenAI multimodal content with text + image_url items), only the text +// items are updated while image_url and other items are preserved. 
+func ReplaceJsonFieldTextContent(body []byte, jsonPath string, newContent string) ([]byte, error) { + resolved := resolveJsonPath(body, jsonPath) + fieldValue := gjson.GetBytes(body, resolved) + if !fieldValue.IsArray() { + // Simple string content — replace directly + return sjson.SetBytes(body, resolved, newContent) + } + // Array content (multimodal): replace text items, preserve others + result := body + var err error + remaining := newContent + items := fieldValue.Array() + // Collect original text lengths for proportional splitting + type textEntry struct { + index int + text string + } + var textEntries []textEntry + totalTextLen := 0 + for i, item := range items { + if item.Get("type").String() == "text" { + t := item.Get("text").String() + textEntries = append(textEntries, textEntry{index: i, text: t}) + totalTextLen += utf8.RuneCountInString(t) + } + } + if len(textEntries) == 0 { + // No text items found, nothing to replace + return body, nil + } + // If there's only one text item, put all desensitized content there + if len(textEntries) == 1 { + itemPath := fmt.Sprintf("%s.%d.text", resolved, textEntries[0].index) + return sjson.SetBytes(result, itemPath, newContent) + } + // Multiple text items: split desensitized content proportionally by original lengths + for j, entry := range textEntries { + var replacement string + if j == len(textEntries)-1 { + // Last text item gets all remaining content + replacement = remaining + } else { + // Proportional split based on original text length (rune-aware) + var proportion int + if totalTextLen == 0 { + // All original text items are empty; roughly even with remainder on later segments + proportion = utf8.RuneCountInString(newContent) / len(textEntries) + } else { + proportion = utf8.RuneCountInString(entry.text) * utf8.RuneCountInString(newContent) / totalTextLen + } + runeCount := utf8.RuneCountInString(remaining) + if proportion > runeCount { + proportion = runeCount + } + // Convert rune count to byte 
offset to split at character boundary + byteOffset := 0 + for i := 0; i < proportion; i++ { + _, size := utf8.DecodeRuneInString(remaining[byteOffset:]) + byteOffset += size + } + replacement = remaining[:byteOffset] + remaining = remaining[byteOffset:] + } + itemPath := fmt.Sprintf("%s.%d.text", resolved, entry.index) + result, err = sjson.SetBytes(result, itemPath, replacement) + if err != nil { + return nil, err + } + } + return result, nil +} + +// resolveJsonPath converts gjson modifier paths (e.g. "messages.@reverse.0.content") +// into concrete index paths (e.g. "messages.2.content") that sjson can handle. +func resolveJsonPath(body []byte, jsonPath string) string { + parts := strings.Split(jsonPath, ".") + var resolved []string + for i := 0; i < len(parts); i++ { + if strings.HasPrefix(parts[i], "@reverse") && i+1 < len(parts) { + // Get the array at the path resolved so far + arrayPath := strings.Join(resolved, ".") + arrayLen := int(gjson.GetBytes(body, arrayPath+".#").Int()) + // Next part should be the reversed index + i++ + reversedIdx := 0 + fmt.Sscanf(parts[i], "%d", &reversedIdx) + actualIdx := arrayLen - 1 - reversedIdx + if actualIdx < 0 { + actualIdx = 0 + } + resolved = append(resolved, fmt.Sprintf("%d", actualIdx)) + } else { + resolved = append(resolved, parts[i]) + } + } + return strings.Join(resolved, ".") +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go b/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go new file mode 100644 index 000000000..6454b8b32 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/utils/utils_test.go @@ -0,0 +1,277 @@ +package utils + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestReplaceJsonFieldTextContent(t *testing.T) { + tests := []struct { + name string + body string + jsonPath string + newContent string + wantCheck func(t *testing.T, result []byte) + }{ + { + name: "string content replaced directly", + body: 
`{"messages":[{"role":"user","content":"我的电话是13800138000"}]}`, + jsonPath: "messages.0.content", + newContent: "我的电话是1**********", + wantCheck: func(t *testing.T, result []byte) { + got := gjson.GetBytes(result, "messages.0.content").String() + if got != "我的电话是1**********" { + t.Errorf("content = %q, want %q", got, "我的电话是1**********") + } + }, + }, + { + name: "array content preserves image_url items", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"我的电话是13800138000"},{"type":"image_url","image_url":{"url":"https://example.com/img.png"}}]}]}`, + jsonPath: "messages.0.content", + newContent: "我的电话是1**********", + wantCheck: func(t *testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.0.content") + if !content.IsArray() { + t.Fatal("content should remain an array") + } + items := content.Array() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + // text item updated + if items[0].Get("type").String() != "text" { + t.Error("first item type should be text") + } + if items[0].Get("text").String() != "我的电话是1**********" { + t.Errorf("text = %q, want %q", items[0].Get("text").String(), "我的电话是1**********") + } + // image_url item preserved + if items[1].Get("type").String() != "image_url" { + t.Error("second item type should be image_url") + } + if items[1].Get("image_url.url").String() != "https://example.com/img.png" { + t.Error("image_url should be preserved") + } + }, + }, + { + name: "array content with multiple text items", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"你好"},{"type":"text","text":"我的电话是13800138000"}]}]}`, + jsonPath: "messages.0.content", + newContent: "你好我的电话是1**********", + wantCheck: func(t *testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.0.content") + if !content.IsArray() { + t.Fatal("content should remain an array") + } + items := content.Array() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + 
// Both items should still be text type + combined := items[0].Get("text").String() + items[1].Get("text").String() + if combined != "你好我的电话是1**********" { + t.Errorf("combined text = %q, want %q", combined, "你好我的电话是1**********") + } + }, + }, + { + name: "array content with only image items returns body unchanged", + body: `{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"https://example.com/a.png"}},{"type":"image_url","image_url":{"url":"https://example.com/b.png"}}]}]}`, + jsonPath: "messages.0.content", + newContent: "masked", + wantCheck: func(t *testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.0.content") + items := content.Array() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + for _, item := range items { + if item.Get("type").String() != "image_url" { + t.Error("all items should remain image_url") + } + } + }, + }, + { + name: "array content text before and after image", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"前缀"},{"type":"image_url","image_url":{"url":"https://img.com/1.png"}},{"type":"text","text":"后缀包含手机号13800138000"}]}]}`, + jsonPath: "messages.0.content", + newContent: "前缀后缀包含手机号1**********", + wantCheck: func(t *testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.0.content") + items := content.Array() + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } + if items[0].Get("type").String() != "text" { + t.Error("item 0 should be text") + } + if items[1].Get("type").String() != "image_url" { + t.Error("item 1 should be image_url") + } + if items[1].Get("image_url.url").String() != "https://img.com/1.png" { + t.Error("image_url should be preserved") + } + if items[2].Get("type").String() != "text" { + t.Error("item 2 should be text") + } + combined := items[0].Get("text").String() + items[2].Get("text").String() + if combined != "前缀后缀包含手机号1**********" { + t.Errorf("combined text = %q, want 
%q", combined, "前缀后缀包含手机号1**********") + } + }, + }, + { + name: "resolveJsonPath with @reverse", + body: `{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"我的电话是13800138000"}]}`, + jsonPath: "messages.@reverse.0.content", + newContent: "我的电话是1**********", + wantCheck: func(t *testing.T, result []byte) { + // @reverse.0 should resolve to the last message (index 1) + got := gjson.GetBytes(result, "messages.1.content").String() + if got != "我的电话是1**********" { + t.Errorf("content = %q, want %q", got, "我的电话是1**********") + } + // system message should be untouched + sys := gjson.GetBytes(result, "messages.0.content").String() + if sys != "sys" { + t.Errorf("system content = %q, want %q", sys, "sys") + } + }, + }, + { + name: "multiple text items with CJK characters split at rune boundary", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"a"},{"type":"text","text":"bbbbbbbbb"}]}]}`, + jsonPath: "messages.0.content", + newContent: "你好12345678", + wantCheck: func(t *testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.0.content") + items := content.Array() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + // Each segment must be valid UTF-8 with no truncated characters + for i, item := range items { + txt := item.Get("text").String() + for _, r := range txt { + if r == '\uFFFD' { + t.Errorf("item %d contains replacement char U+FFFD, text=%q", i, txt) + } + } + } + combined := items[0].Get("text").String() + items[1].Get("text").String() + if combined != "你好12345678" { + t.Errorf("combined text = %q, want %q", combined, "你好12345678") + } + }, + }, + { + name: "multiple empty text items with non-empty newContent no panic", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":""},{"type":"text","text":""},{"type":"image_url","image_url":{"url":"https://img.com/1.png"}}]}]}`, + jsonPath: "messages.0.content", + newContent: "脱敏后的内容abc", + wantCheck: func(t 
*testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.0.content") + items := content.Array() + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } + // image_url item preserved + if items[2].Get("type").String() != "image_url" { + t.Error("item 2 should be image_url") + } + // All newContent must be distributed across the two text items + combined := items[0].Get("text").String() + items[1].Get("text").String() + if combined != "脱敏后的内容abc" { + t.Errorf("combined text = %q, want %q", combined, "脱敏后的内容abc") + } + }, + }, + { + name: "resolveJsonPath with @reverse and array content", + body: `{"messages":[{"role":"system","content":"sys"},{"role":"user","content":[{"type":"text","text":"敏感内容"},{"type":"image_url","image_url":{"url":"https://img.com/x.png"}}]}]}`, + jsonPath: "messages.@reverse.0.content", + newContent: "脱敏内容", + wantCheck: func(t *testing.T, result []byte) { + content := gjson.GetBytes(result, "messages.1.content") + if !content.IsArray() { + t.Fatal("content should remain an array") + } + items := content.Array() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + if items[0].Get("text").String() != "脱敏内容" { + t.Errorf("text = %q, want %q", items[0].Get("text").String(), "脱敏内容") + } + if items[1].Get("image_url.url").String() != "https://img.com/x.png" { + t.Error("image_url should be preserved") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ReplaceJsonFieldTextContent([]byte(tt.body), tt.jsonPath, tt.newContent) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Verify result is valid JSON + if !gjson.ValidBytes(result) { + t.Fatal("result is not valid JSON") + } + tt.wantCheck(t, result) + }) + } +} + +// TestResolveJsonPathEdgeCases covers edge cases in resolveJsonPath +func TestResolveJsonPathEdgeCases(t *testing.T) { + // @reverse with index exceeding array length → actualIdx clamped to 0 + 
 t.Run("@reverse out of bounds index clamps to 0", func(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) + // Array has 1 element (index 0). @reverse.5 → actualIdx = 0 - 5 = -5 → clamped to 0 + result, err := ReplaceJsonFieldTextContent(body, "messages.@reverse.5.content", "replaced") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !gjson.ValidBytes(result) { + t.Fatal("result is not valid JSON") + } + got := gjson.GetBytes(result, "messages.0.content").String() + if got != "replaced" { + t.Errorf("content = %q, want %q", got, "replaced") + } + }) + + // @reverse.0 on a single-element array → actualIdx = 1 - 1 - 0 = 0 (no clamping needed) + t.Run("@reverse on single-element array resolves correctly", func(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"only one"}]}`) + // Array has 1 element. @reverse.0 → actualIdx = 1 - 1 - 0 = 0 + result, err := ReplaceJsonFieldTextContent(body, "messages.@reverse.0.content", "updated") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got := gjson.GetBytes(result, "messages.0.content").String() + if got != "updated" { + t.Errorf("content = %q, want %q", got, "updated") + } + }) +} + +// TestReplaceJsonFieldContent covers the simple ReplaceJsonFieldContent function +func TestReplaceJsonFieldContent(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"original"}]}`) + result, err := ReplaceJsonFieldContent(body, "messages.0.content", "replaced") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got := gjson.GetBytes(result, "messages.0.content").String() + if got != "replaced" { + t.Errorf("content = %q, want %q", got, "replaced") + } +} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/go.sum b/plugins/wasm-go/mcp-servers/amap-tools/go.sum index 0062a8784..b6a2c9f14 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/go.sum +++ b/plugins/wasm-go/mcp-servers/amap-tools/go.sum @@ -22,8 +22,10 @@ github.com/higress-group/gjson_template
v0.0.0-20250413075336-4c4161ed428b h1:rR github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= @@ -49,6 +51,7 @@ github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/plugins/wasm-go/mcp-servers/quark-search/go.sum b/plugins/wasm-go/mcp-servers/quark-search/go.sum index 0062a8784..b6a2c9f14 100644 --- 
a/plugins/wasm-go/mcp-servers/quark-search/go.sum +++ b/plugins/wasm-go/mcp-servers/quark-search/go.sum @@ -22,8 +22,10 @@ github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rR github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= @@ -49,6 +51,7 @@ github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=