// Copyright (c) 2024 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "encoding/json" "strings" "testing" cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/proxytest" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/test" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) // 测试配置:基础安全配置 var basicConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, "bufferLimit": 1000, }) return data }() // 测试配置:仅检查请求 var requestOnlyConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 1000, "bufferLimit": 500, }) return data }() // 测试配置:缺少必需字段 var 
missingRequiredConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "accessKey": "test-ak", "secretKey": "test-sk", // 故意缺少必需字段:serviceName, servicePort, serviceHost }) return data }() // 测试配置:缺少服务配置字段 var missingServiceConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, // 缺少 serviceName, servicePort, serviceHost }) return data }() // 测试配置:缺少认证字段 var missingAuthConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "checkRequest": true, "checkResponse": true, // 缺少 accessKey, secretKey }) return data }() // 测试配置:消费者级别特殊配置 var consumerSpecificConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "maliciousUrlLevelBar": "high", "modelHallucinationLevelBar": "high", "timeout": 1000, "bufferLimit": 500, "consumerRequestCheckService": map[string]interface{}{ "name": "aaa", "matchType": "exact", "requestCheckService": "llm_query_moderation_1", }, "consumerResponseCheckService": map[string]interface{}{ "name": "bbb", "matchType": "prefix", "responseCheckService": "llm_response_moderation_1", }, "consumerRiskLevel": map[string]interface{}{ "name": "ccc.*", "matchType": "regexp", "maliciousUrlLevelBar": "low", }, }) return data }() // 测试配置:包含 customLabelLevelBar 和消费者级别覆盖 var customLabelConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", 
"secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "customLabelLevelBar": "high", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "consumerRiskLevel": []map[string]interface{}{ { "name": "exact-user", "matchType": "exact", "customLabelLevelBar": "low", }, { "name": "prefix-", "matchType": "prefix", "customLabelLevelBar": "medium", }, }, }) return data }() // 测试配置:脱敏模式配置(riskAction=mask) var maskConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "action": "MultiModalGuard", "riskAction": "mask", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:MCP配置 var mcpConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": false, "checkResponse": true, "action": "MultiModalGuard", "apiType": "mcp", "responseContentJsonPath": "content", "responseStreamContentJsonPath": "content", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() var mcpRequestConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "action": "MultiModalGuard", "apiType": "mcp", "requestContentJsonPath": "params.arguments", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 
测试配置:MultiModalGuard 文本生成 var multiModalGuardTextConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "apiType": "text_generation", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, "bufferLimit": 1000, }) return data }() // 测试配置:MultiModalGuard OpenAI 图像生成 var multiModalGuardImageOpenAIConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "apiType": "image_generation", "providerType": "openai", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:MultiModalGuard Qwen 图像生成 var multiModalGuardImageQwenConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "apiType": "image_generation", "providerType": "qwen", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:ProtocolOriginal MultiModalGuard var protocolOriginalConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "action": 
"MultiModalGuard", "protocol": "original", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() func TestParseConfig(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试基础配置解析 t.Run("basic config", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "test-ak", securityConfig.AK) require.Equal(t, "test-sk", securityConfig.SK) require.Equal(t, true, securityConfig.CheckRequest) require.Equal(t, true, securityConfig.CheckResponse) require.Equal(t, "high", securityConfig.ContentModerationLevelBar) require.Equal(t, "high", securityConfig.PromptAttackLevelBar) require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) require.Equal(t, uint32(2000), securityConfig.Timeout) require.Equal(t, 1000, securityConfig.BufferLimit) }) // 测试仅检查请求的配置 t.Run("request only config", func(t *testing.T) { host, status := test.NewTestHost(requestOnlyConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, true, securityConfig.CheckRequest) require.Equal(t, false, securityConfig.CheckResponse) require.Equal(t, "high", securityConfig.ContentModerationLevelBar) require.Equal(t, "high", securityConfig.PromptAttackLevelBar) require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) }) // 测试缺少必需字段的配置 t.Run("missing required config", func(t *testing.T) { host, status := test.NewTestHost(missingRequiredConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) // 测试缺少服务配置字段 t.Run("missing service config", func(t *testing.T) { host, status := 
test.NewTestHost(missingServiceConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) // 测试缺少认证字段 t.Run("missing auth config", func(t *testing.T) { host, status := test.NewTestHost(missingAuthConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) // 测试消费者级别配置 t.Run("consumer specific config", func(t *testing.T) { host, status := test.NewTestHost(consumerSpecificConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa")) require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa")) require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb")) require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test")) require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc")) require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test")) }) }) } func TestOnHttpRequestHeaders(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试启用请求检查的情况 t.Run("request checking enabled", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 设置请求头 action := host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) // 测试禁用请求检查的情况 t.Run("request checking disabled", func(t *testing.T) { host, status := test.NewTestHost(requestOnlyConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 设置请求头 action := host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", 
"/v1/chat/completions"}, {":method", "POST"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) }) } func TestOnHttpRequestBody(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试请求体安全检查通过 t.Run("request body security check pass", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置请求体 body := `{"messages": [{"role": "user", "content": "Hello, how are you?"}]}` action := host.CallOnHttpRequestBody([]byte(body)) // 应该返回ActionPause,等待安全检查结果 require.Equal(t, types.ActionPause, action) // 模拟安全检查服务响应(通过) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 测试空请求内容 t.Run("empty request content", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置空内容的请求体 body := `{"messages": [{"role": "user", "content": ""}]}` action := host.CallOnHttpRequestBody([]byte(body)) // 空内容应该直接通过 require.Equal(t, types.ActionContinue, action) }) }) } func TestOnHttpResponseHeaders(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试启用响应检查的情况 t.Run("response checking enabled", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", 
"/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 action := host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 应该返回HeaderStopIteration require.Equal(t, types.HeaderStopIteration, action) }) // 测试禁用响应检查的情况 t.Run("response checking disabled", func(t *testing.T) { host, status := test.NewTestHost(requestOnlyConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 action := host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) // 测试非200状态码 t.Run("non-200 status code", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置非200响应头 action := host.CallOnHttpResponseHeaders([][2]string{ {":status", "500"}, {"content-type", "application/json"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) }) } func TestOnHttpResponseBody(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试响应体安全检查通过 t.Run("response body security check pass", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 设置响应体 body := `{"choices": [{"message": {"role": "assistant", "content": "Hello, how can I help you?"}}]}` action := 
host.CallOnHttpResponseBody([]byte(body)) // 应该返回ActionPause,等待安全检查结果 require.Equal(t, types.ActionPause, action) // 模拟安全检查服务响应(通过) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 测试空响应内容 t.Run("empty response content", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 设置空内容的响应体 body := `{"choices": [{"message": {"role": "assistant", "content": ""}}]}` action := host.CallOnHttpResponseBody([]byte(body)) // 空内容应该直接通过 require.Equal(t, types.ActionContinue, action) }) }) } func TestMCP(t *testing.T) { test.RunTest(t, func(t *testing.T) { // Test MCP Response Body Check - Pass t.Run("mcp response body security check pass", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"x-mse-consumer", "test-user"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // body content matching responseContentJsonPath="content" body := `{"content": "Hello world"}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` 
host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // Test MCP Response Body Check - Deny t.Run("mcp response body security check deny", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) body := `{"content": "Bad content"}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) // High Risk securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // Verify it was replaced with DenyResponse // Can't easily verify the replaced body content with current test wrapper but can check action // Since plugin calls SendHttpResponse, execution stops or changes. // mcp.go uses SendHttpResponse(..., DenyResponse, -1) which means it ends the stream. // We can check if GetHttpStreamAction is ActionPause (since it did send a response) or something else. // Actually SendHttpResponse in proxy-wasm usually terminates further processing of the original stream. 
}) // Test MCP Streaming Response Body Check - Pass t.Run("mcp streaming response body security check pass", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "text/event-stream"}, }) // streaming chunk // config uses "content" key chunk := []byte(`data: {"content": "Hello"}` + "\n\n") // This calls OnHttpStreamingResponseBody -> mcp.HandleMcpStreamingResponseBody // It should push buffer and make call host.CallOnHttpStreamingResponseBody(chunk, false) // Action assertion removed as it returns an internal value 3 securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) }) // Test MCP Streaming Response Body Check - Deny t.Run("mcp streaming response body security check deny", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "text/event-stream"}, }) chunk := []byte(`data: {"content": "Bad"}` + "\n\n") host.CallOnHttpStreamingResponseBody(chunk, false) // High Risk securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // It injects DenySSEResponse. 
}) }) } func TestGetRiskAction(t *testing.T) { // 测试全局默认值 t.Run("default is block", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() require.Equal(t, "block", config.GetRiskAction("any-consumer")) }) // 测试全局配置为 mask,无消费者覆盖 t.Run("global mask without consumer override", func(t *testing.T) { config := cfg.AISecurityConfig{RiskAction: "mask"} require.Equal(t, "mask", config.GetRiskAction("any-consumer")) }) // 测试消费者级别覆盖 riskAction t.Run("consumer overrides riskAction to mask", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "vip-user"}, "riskAction": "mask", }, }, } require.Equal(t, "mask", config.GetRiskAction("vip-user")) require.Equal(t, "block", config.GetRiskAction("normal-user")) }) // 测试消费者匹配但未配置 riskAction,fallback 到全局 t.Run("consumer matched without riskAction falls back to global", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "mask", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "some-user"}, "contentModerationLevelBar": "low", }, }, } require.Equal(t, "mask", config.GetRiskAction("some-user")) }) // 测试 prefix 匹配 t.Run("consumer prefix match", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Prefix: "test-"}, "riskAction": "mask", }, }, } require.Equal(t, "mask", config.GetRiskAction("test-user-1")) require.Equal(t, "block", config.GetRiskAction("prod-user")) }) // 测试空 consumer 不匹配任何规则 t.Run("empty consumer falls back to global", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "mask", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "vip"}, "riskAction": "block", }, }, } require.Equal(t, "mask", config.GetRiskAction("")) }) } func TestEvaluateRiskWithConsumerRiskAction(t *testing.T) { // 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 
proxywasm.LogInfof opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{}) _, reset := proxytest.NewHostEmulator(opt) defer reset() // 测试全局 block,消费者 mask t.Run("global block consumer mask", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "vip-user"}, "riskAction": "mask", "sensitiveDataLevelBar": "S2", }, }, } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Suggestion: "mask", Type: "sensitiveData", Level: "S2", Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}}, }, } // vip-user 使用 mask 模式,consumer 阈值 S2,Level=S2 >= S2 → RiskMask require.Equal(t, cfg.RiskMask, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "vip-user")) // normal-user 使用全局 block 模式,全局阈值 S4,Level=S2 < S4 → RiskPass require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "normal-user")) }) // 测试全局 mask,消费者 block t.Run("global mask consumer block", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "mask", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S2", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "strict-user"}, "riskAction": "block", }, }, } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Suggestion: "mask", Type: "sensitiveData", Level: "S2", Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}}, }, } // strict-user 使用 block 模式,Level=S2 >= S2 但 Suggestion=mask + dimAction=block → detailTriggersBlock 返回 false(mask suggestion 不触发 block) // 实际上 detailTriggersBlock: Suggestion != "block", dimAction == "block" → return exceeds // exceeds = S2 >= S2 = true → RiskBlock // 所以 strict-user 应该是 
RiskBlock require.Equal(t, cfg.RiskBlock, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "strict-user")) // other-user 使用全局 mask 模式,Level=S2 >= S2 → RiskMask require.Equal(t, cfg.RiskMask, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "other-user")) }) } func TestParseConsumerRiskActionValidation(t *testing.T) { // 测试消费者级别 riskAction 无效值 t.Run("invalid consumer riskAction", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() configJSON := `{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "consumerRiskLevel": [ {"name": "user1", "matchType": "exact", "riskAction": "invalid"} ] }` err := config.Parse(gjson.Parse(configJSON)) require.Error(t, err) require.Contains(t, err.Error(), "invalid riskAction in consumerRiskLevel") }) // 测试消费者级别 riskAction 有效值 t.Run("valid consumer riskAction", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() configJSON := `{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "consumerRiskLevel": [ {"name": "user1", "matchType": "exact", "riskAction": "mask"} ] }` err := config.Parse(gjson.Parse(configJSON)) require.NoError(t, err) require.Equal(t, "mask", config.GetRiskAction("user1")) require.Equal(t, "block", config.GetRiskAction("other")) }) } func TestRiskLevelFunctions(t *testing.T) { // 测试风险等级转换函数 t.Run("risk level conversion", func(t *testing.T) { require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk)) require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk)) require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk)) require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk)) require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk)) require.Equal(t, -1, cfg.LevelToInt("invalid")) }) // 测试风险等级比较 t.Run("risk level comparison", func(t *testing.T) { require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk)) 
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk)) require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk)) require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk)) }) } func TestUtilityFunctions(t *testing.T) { // 测试十六进制ID生成函数 t.Run("hex id generation", func(t *testing.T) { id, err := utils.GenerateHexID(16) require.NoError(t, err) require.Len(t, id, 16) require.Regexp(t, "^[0-9a-f]+$", id) }) // 测试随机ID生成函数 t.Run("random id generation", func(t *testing.T) { id := utils.GenerateRandomChatID() require.NotEmpty(t, id) require.Contains(t, id, "chatcmpl-") require.Len(t, id, 38) // "chatcmpl-" + 29 random chars }) } func TestCustomLabelConstant(t *testing.T) { // 验证 CustomLabelType 常量值 require.Equal(t, "customLabel", cfg.CustomLabelType) } func TestCustomLabelConfigParsing(t *testing.T) { // 测试 customLabelLevelBar 设置为 high t.Run("customLabelLevelBar set to high", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() configJSON := `{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "customLabelLevelBar": "high" }` err := config.Parse(gjson.Parse(configJSON)) require.NoError(t, err) require.Equal(t, "high", config.CustomLabelLevelBar) }) // 测试 customLabelLevelBar 设置为 max t.Run("customLabelLevelBar set to max", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() configJSON := `{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "customLabelLevelBar": "max" }` err := config.Parse(gjson.Parse(configJSON)) require.NoError(t, err) require.Equal(t, "max", config.CustomLabelLevelBar) }) // 测试 customLabelLevelBar 缺省时默认为 max t.Run("customLabelLevelBar defaults to max", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() configJSON := `{ 
"serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk" }` err := config.Parse(gjson.Parse(configJSON)) require.NoError(t, err) require.Equal(t, "max", config.CustomLabelLevelBar) }) // 测试 customLabelLevelBar 无效值 t.Run("customLabelLevelBar invalid value", func(t *testing.T) { config := cfg.AISecurityConfig{} config.SetDefaultValues() configJSON := `{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "customLabelLevelBar": "invalid" }` err := config.Parse(gjson.Parse(configJSON)) require.Error(t, err) require.Contains(t, err.Error(), "invalid customLabelLevelBar") }) } func TestGetCustomLabelLevelBar(t *testing.T) { // 测试消费者精确匹配 t.Run("consumer exact match", func(t *testing.T) { config := cfg.AISecurityConfig{ CustomLabelLevelBar: "max", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "exact-user"}, "customLabelLevelBar": "low", }, }, } require.Equal(t, "low", config.GetCustomLabelLevelBar("exact-user")) }) // 测试消费者前缀匹配 t.Run("consumer prefix match", func(t *testing.T) { config := cfg.AISecurityConfig{ CustomLabelLevelBar: "max", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Prefix: "prefix-"}, "customLabelLevelBar": "medium", }, }, } require.Equal(t, "medium", config.GetCustomLabelLevelBar("prefix-user")) }) // 测试无匹配回退全局值 t.Run("no match falls back to global", func(t *testing.T) { config := cfg.AISecurityConfig{ CustomLabelLevelBar: "high", ConsumerRiskLevel: []map[string]interface{}{ { "matcher": cfg.Matcher{Exact: "other-user"}, "customLabelLevelBar": "low", }, }, } require.Equal(t, "high", config.GetCustomLabelLevelBar("unmatched-user")) }) } func TestCustomLabelDetailExceedsThreshold(t *testing.T) { // 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof opt := 
proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{}) _, reset := proxytest.NewHostEmulator(opt) defer reset() // 测试 customLabel Level=high, threshold=high → 拦截 (true) t.Run("level high threshold high blocks", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "high", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Type: cfg.CustomLabelType, Level: "high"}, }, } require.Equal(t, cfg.RiskBlock, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "")) }) // 测试 customLabel Level=none, threshold=high → 不拦截 (false) t.Run("level none threshold high passes", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "high", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Type: cfg.CustomLabelType, Level: "none"}, }, } require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "")) }) // 测试 customLabel Level=high, threshold=max → 不拦截 (false) t.Run("level high threshold max passes", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "max", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Type: cfg.CustomLabelType, Level: "high"}, }, } require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "")) }) } func TestCustomLabelConfigIntegration(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试 customLabelConfig 配置解析和消费者覆盖 t.Run("customLabel config with consumer override", func(t 
*testing.T) { host, status := test.NewTestHost(customLabelConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "high", securityConfig.CustomLabelLevelBar) require.Equal(t, "low", securityConfig.GetCustomLabelLevelBar("exact-user")) require.Equal(t, "medium", securityConfig.GetCustomLabelLevelBar("prefix-user")) require.Equal(t, "high", securityConfig.GetCustomLabelLevelBar("unknown-user")) }) }) } func TestRequestMasking(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试请求阶段脱敏成功:riskAction=mask,API 返回 mask 建议,请求体被替换为脱敏内容 t.Run("request masking success", func(t *testing.T) { host, status := test.NewTestHost(maskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议,包含脱敏内容 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "low", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", "Result": [{ "Label": "phone_number", "Confidence": 99.0, "Ext": { "Desensitization": "我的电话是1**********", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // 验证请求体被替换为脱敏内容 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "我的电话是1**********", content) host.CompleteHttp() }) // 测试脱敏内容为空时回退到拦截 t.Run("empty desensitization falls back to block", func(t 
*testing.T) { host, status := test.NewTestHost(maskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议,但 Desensitization 为空 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "low", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", "Result": [{ "Label": "phone_number", "Confidence": 99.0, "Ext": { "Desensitization": "", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // Desensitization 为空时应回退到拦截,SendHttpResponse 被调用 // 验证请求体未被修改(原始内容保持不变) processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "我的电话是13800138000", content) }) // 测试 riskAction=block 时 mask 建议按现有逻辑处理(向后兼容) t.Run("riskAction block keeps existing behavior for mask suggestion", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议,但全局 riskAction 默认为 block securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "low", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S2", "Result": [{ 
"Label": "phone_number", "Confidence": 99.0, "Ext": { "Desensitization": "我的电话是1**********", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // riskAction=block 时,mask 建议按风险等级判断,low 级别应放行 action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) // 请求体不应被脱敏修改 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "我的电话是13800138000", content) host.CompleteHttp() }) // 测试 riskAction=mask 时 block 建议优先拦截 t.Run("block suggestion takes priority over mask", func(t *testing.T) { host, status := test.NewTestHost(maskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "违规内容"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 block 建议 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "high", "Detail": [{ "Suggestion": "block", "Type": "contentModeration", "Level": "high" }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // block 建议应拦截请求,请求体不应被修改 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "违规内容", content) }) }) } // 测试配置:MCP + 脱敏模式 var mcpMaskConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": 
true, "action": "MultiModalGuard", "apiType": "mcp", "riskAction": "mask", "requestContentJsonPath": "params.arguments.input", "responseContentJsonPath": "content", "responseStreamContentJsonPath": "content", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() func TestIsRiskLevelAcceptable(t *testing.T) { // 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{}) _, reset := proxytest.NewHostEmulator(opt) defer reset() // 用例 1: riskAction=mask, Suggestion=mask → 应返回 true(mask 不应被视为不可接受) t.Run("mask action with mask suggestion is acceptable", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "mask", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "max", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Suggestion: "mask", Type: "sensitiveData", Level: "S2", Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}}, }, } require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) }) // 用例 2: riskAction=mask, Suggestion=block → 应返回 false t.Run("mask action with block suggestion is not acceptable", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "mask", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "max", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Suggestion: "block", Type: "contentModeration", Level: "high"}, }, } require.False(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) }) // 用例 3: riskAction=mask, 无风险 → 应返回 true t.Run("mask action with no risk is acceptable", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: 
"mask", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "max", } data := cfg.Data{RiskLevel: "low"} require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) }) // 用例 4: riskAction=block, Suggestion=mask, level 未超阈值 → 应返回 true(向后兼容) t.Run("block action with mask suggestion below threshold is acceptable", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "max", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Suggestion: "mask", Type: "sensitiveData", Level: "S2"}, }, } require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) }) // 用例 5: riskAction=block, Suggestion=mask, level 超阈值 → 应返回 false(向后兼容) t.Run("block action with mask suggestion exceeding threshold is not acceptable", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "block", ContentModerationLevelBar: "max", PromptAttackLevelBar: "max", SensitiveDataLevelBar: "S2", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", CustomLabelLevelBar: "max", } data := cfg.Data{ RiskLevel: "none", Detail: []cfg.Detail{ {Suggestion: "mask", Type: "sensitiveData", Level: "S2"}, }, } require.False(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, "")) }) // 用例 6: TextModerationPlus, riskAction=mask → 不受影响 t.Run("TextModerationPlus not affected by riskAction mask", func(t *testing.T) { config := cfg.AISecurityConfig{ RiskAction: "mask", RiskLevelBar: "high", } data := cfg.Data{RiskLevel: "low"} require.True(t, cfg.IsRiskLevelAcceptable(cfg.TextModerationPlus, data, config, "")) }) } func TestMcpMaskNotBlock(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 用例 7: MCP 请求, riskAction=mask, API 返回 
Suggestion=mask → 应放行 t.Run("mcp request with mask suggestion should pass not block", func(t *testing.T) { host, status := test.NewTestHost(mcpMaskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/mcp"}, {":method", "POST"}, }) body := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test","arguments":{"input":"我的电话是13800138000"}}}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "low", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S2", "Result": [{ "Label": "phone_number", "Confidence": 99.0, "Ext": { "Desensitization": "我的电话是1**********", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // 应放行而非拦截 action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 用例 8: MCP 响应, riskAction=mask, API 返回 Suggestion=mask → 应放行 t.Run("mcp response with mask suggestion should pass not block", func(t *testing.T) { host, status := test.NewTestHost(mcpMaskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/mcp"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) body := `{"content": "我的电话是13800138000"}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "low", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S2", 
"Result": [{ "Label": "phone_number", "Confidence": 99.0, "Ext": { "Desensitization": "我的电话是1**********", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // 应放行而非拦截 action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 用例 9: MCP 请求, riskAction=mask, API 返回 Suggestion=block → 应拦截 t.Run("mcp request with block suggestion should deny", func(t *testing.T) { host, status := test.NewTestHost(mcpMaskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/mcp"}, {":method", "POST"}, }) body := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test","arguments":{"input":"违规内容"}}}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 block 建议 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-123", "Data": { "RiskLevel": "high", "Detail": [{ "Suggestion": "block", "Type": "contentModeration", "Level": "high" }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // block 建议应拦截(SendHttpResponse 被调用,请求不会继续) // MCP handler 调用 SendHttpResponse 后不会 resume }) }) } // ============================================================================= // TC-PARSE: 配置解析与校验集成测试 // ============================================================================= // 测试配置:MultiModalGuard + 全局维度动作全为合法值 var dimensionActionValidConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "contentModerationAction": "block", 
"promptAttackAction": "block", "sensitiveDataAction": "mask", "maliciousUrlAction": "block", "modelHallucinationAction": "block", "customLabelAction": "block", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", }) return data }() // 测试配置:MultiModalGuard + 全局维度动作出现非法值 var dimensionActionInvalidConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "contentModerationAction": "allow", // 非法值 }) return data }() // 测试配置:MultiModalGuard + consumerRiskLevel 内维度动作非法 var consumerDimensionActionInvalidConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "action": "MultiModalGuard", "consumerRiskLevel": []map[string]interface{}{ { "name": "user-a", "matchType": "exact", "sensitiveDataAction": "deny", // 非法值 }, }, }) return data }() // 测试配置:TextModerationPlus + 配置了维度动作 var textModPlusDimensionActionConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "action": "TextModerationPlus", "sensitiveDataAction": "mask", "contentModerationAction": "block", }) return data }() // 测试配置:未配置任何动作字段 var noActionFieldConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, }) return data }() // TestTC_PARSE_001 MultiModalGuard + 
全局维度动作全为合法值 => 启动成功 func TestTC_PARSE_001(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { host, status := test.NewTestHost(dimensionActionValidConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "block", securityConfig.ContentModerationAction) require.Equal(t, "block", securityConfig.PromptAttackAction) require.Equal(t, "mask", securityConfig.SensitiveDataAction) require.Equal(t, "block", securityConfig.MaliciousUrlAction) require.Equal(t, "block", securityConfig.ModelHallucinationAction) require.Equal(t, "block", securityConfig.CustomLabelAction) }) } // TestTC_PARSE_002 MultiModalGuard + 全局维度动作出现非法值 => 启动失败 func TestTC_PARSE_002(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { host, status := test.NewTestHost(dimensionActionInvalidConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) } // TestTC_PARSE_003 MultiModalGuard + consumerRiskLevel 内维度动作非法 => 启动失败 func TestTC_PARSE_003(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { host, status := test.NewTestHost(consumerDimensionActionInvalidConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) } // TestTC_PARSE_004 TextModerationPlus + 配置了维度动作 => 启动成功(字段忽略) func TestTC_PARSE_004(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { host, status := test.NewTestHost(textModPlusDimensionActionConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) securityConfig := config.(*cfg.AISecurityConfig) // 字段被解析但在运行时被忽略(不影响启动) require.Equal(t, "mask", securityConfig.SensitiveDataAction) require.Equal(t, "block", securityConfig.ContentModerationAction) }) } // TestTC_PARSE_005 未配置任何动作字段 => 默认 riskAction=block func TestTC_PARSE_005(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { host, status := 
test.NewTestHost(noActionFieldConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "block", securityConfig.RiskAction) require.Equal(t, "", securityConfig.ContentModerationAction) require.Equal(t, "", securityConfig.PromptAttackAction) require.Equal(t, "", securityConfig.SensitiveDataAction) require.Equal(t, "", securityConfig.MaliciousUrlAction) require.Equal(t, "", securityConfig.ModelHallucinationAction) require.Equal(t, "", securityConfig.CustomLabelAction) }) } // TestTC_PARSE_006 TextModerationPlus + 非法维度动作值 => 启动成功(需求 8.2) func TestTC_PARSE_006(t *testing.T) { // 非 MultiModalGuard 下配置了非法维度动作值,不应报错,应启动成功 invalidDimActionTextModConfig := func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "action": "TextModerationPlus", "contentModerationAction": "allow", // 非法值,但非 MultiModalGuard 下应忽略 "sensitiveDataAction": "deny", // 非法值 }) return data }() test.RunGoTest(t, func(t *testing.T) { host, status := test.NewTestHost(invalidDimActionTextModConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) }) } // TestTC_PARSE_007 TextModerationPlus + consumerRiskLevel 内非法维度动作值 => 启动成功(需求 8.2) func TestTC_PARSE_007(t *testing.T) { invalidConsumerDimActionTextModConfig := func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "action": "TextModerationPlus", "consumerRiskLevel": []map[string]interface{}{ { "name": "user-a", "matchType": "exact", "sensitiveDataAction": "invalid-value", // 非法值,但非 MultiModalGuard 下应忽略 }, }, }) return data }() 
test.RunGoTest(t, func(t *testing.T) { host, status := test.NewTestHost(invalidConsumerDimActionTextModConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) }) } // ============================================================================= // TC-REG: 回归测试 // ============================================================================= // 测试配置:历史仅 riskAction=block 的 MultiModalGuard 配置(无维度动作字段) var legacyRiskActionBlockConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "riskAction": "block", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:历史仅 riskAction=mask 的 MultiModalGuard 配置(无维度动作字段) var legacyRiskActionMaskConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "riskAction": "mask", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // TestTC_REG_004 历史仅 riskAction 配置的场景不回归 // 验证:仅配置 riskAction(不配置任何维度动作字段)时,新代码行为与历史一致 func TestTC_REG_004(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 子用例 1: riskAction=block,请求安全检查通过(低风险)=> 放行 t.Run("legacy block config pass on low risk", func(t *testing.T) { host, status := test.NewTestHost(legacyRiskActionBlockConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := 
`{"messages": [{"role": "user", "content": "Hello"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回低风险,无 Detail 触发 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-reg-001", "Data": { "RiskLevel": "none", "Detail": [] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 子用例 2: riskAction=block,顶层 RiskLevel 超阈值 => 拦截 t.Run("legacy block config blocks on high risk level", func(t *testing.T) { host, status := test.NewTestHost(legacyRiskActionBlockConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "违规内容"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回高风险,顶层 RiskLevel 超阈值 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-reg-002", "Data": { "RiskLevel": "high", "Detail": [{ "Suggestion": "block", "Type": "contentModeration", "Level": "high" }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // 拦截:请求体不应被修改 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "违规内容", content) }) // 子用例 3: riskAction=block,Detail 有 mask 建议但 level 未超阈值 => 放行(block 模式忽略 mask) t.Run("legacy block config ignores mask suggestion below threshold", func(t *testing.T) { host, status := test.NewTestHost(legacyRiskActionBlockConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ 
{":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议,但 riskAction=block 且 level 未超阈值 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-reg-003", "Data": { "RiskLevel": "none", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S2", "Result": [{ "Label": "phone_number", "Confidence": 99.0, "Ext": { "Desensitization": "我的电话是1**********", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // riskAction=block 时,mask 建议不触发脱敏,level 未超阈值应放行 action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) // 请求体不应被脱敏修改 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "我的电话是13800138000", content) host.CompleteHttp() }) // 子用例 4: riskAction=mask,Detail 有 mask 建议 => 脱敏替换 t.Run("legacy mask config applies desensitization", func(t *testing.T) { host, status := test.NewTestHost(legacyRiskActionMaskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 mask 建议,riskAction=mask 应触发脱敏 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-reg-004", "Data": { "RiskLevel": "none", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", "Result": [{ "Label": "phone_number", "Confidence": 99.0, "Ext": { 
"Desensitization": "我的电话是1**********", "SensitiveData": ["13800138000"] } }] }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // riskAction=mask 时,mask 建议应触发脱敏替换 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "我的电话是1**********", content) host.CompleteHttp() }) // 子用例 5: riskAction=mask,Detail 有 block 建议 => 仍然拦截(block 优先) t.Run("legacy mask config still blocks on block suggestion", func(t *testing.T) { host, status := test.NewTestHost(legacyRiskActionMaskConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "违规内容"}]}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionPause, action) // API 返回 block 建议,即使 riskAction=mask 也应拦截 securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-reg-005", "Data": { "RiskLevel": "high", "Detail": [{ "Suggestion": "block", "Type": "contentModeration", "Level": "high" }] } }` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // block 建议应拦截,请求体不应被修改 processedBody := host.GetRequestBody() require.NotNil(t, processedBody) content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String() require.Equal(t, "违规内容", content) }) }) } func TestMultiModalGuardTextGenerationDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // MultiModalGuard text_generation request deny → exercises multi_modal_guard/text/openai.go BuildDenyResponseBody path t.Run("multi modal guard text request deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardTextConfig) defer 
host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-text-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for request deny") require.Contains(t, string(local.Data), "blockedDetails") }) // MultiModalGuard text_generation response deny → exercises common/text/openai.go HandleTextGenerationResponseBody BuildDenyResponseBody path t.Run("multi modal guard text response deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardTextConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) body := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-resp-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for response deny") require.Contains(t, string(local.Data), "blockedDetails") }) // MultiModalGuard text_generation 
request pass
		t.Run("multi modal guard text request pass", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardTextConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "Hello"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-pass", "Data": {"RiskLevel": "low"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			action := host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})
	})
}

// TestMultiModalGuardImageGenerationDeny drives /v1/images/generations requests
// through the plugin for both the OpenAI-style ({"prompt": ...}) and the
// Qwen-style ({"input": {"prompt": ...}}) image configs. It checks that a
// "high" RiskLevel from the security service yields a local response
// containing blockedDetails, and that a "low" RiskLevel lets the request continue.
func TestMultiModalGuardImageGenerationDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// OpenAI image generation request deny → exercises multi_modal_guard/image/openai.go BuildDenyResponseBody path
		t.Run("openai image request deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
			})

			body := `{"prompt": "generate bad image"}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-openai-deny", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for OpenAI image request deny")
			require.Contains(t, string(local.Data), "blockedDetails")
		})

		// OpenAI image generation request pass
		t.Run("openai image request pass", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
			})

			body := `{"prompt": "a cute cat"}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-pass", "Data": {"RiskLevel": "low"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			action := host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})

		// Qwen image generation request deny → exercises multi_modal_guard/image/qwen.go BuildDenyResponseBody path
		t.Run("qwen image request deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
			})

			body := `{"input": {"prompt": "generate bad image"}}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-qwen-deny", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for Qwen image request deny")
			require.Contains(t, string(local.Data), "blockedDetails")
		})

		// Qwen image generation request pass
		t.Run("qwen image request pass", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
			})

			body := `{"input": {"prompt": "a cute cat"}}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-qwen-pass", "Data": {"RiskLevel": "low"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			action := host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})
	})
}

// TestMCPRequestDeny covers MCP request handling against mcpRequestConfig.
// A "tools/call" request is denied on a "high" risk verdict and continues on
// "low". A non-tool-call method ("resources/list") is passed through without
// pausing for a security check.
func TestMCPRequestDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// MCP request deny → exercises multi_modal_guard/mcp/mcp.go HandleMcpRequestBody BuildDenyResponseBody path
		t.Run("mcp request deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(mcpRequestConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/mcp/call"},
				{":method", "POST"},
			})

			body := `{"method": "tools/call", "params": {"arguments": "bad request content"}}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-deny", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for MCP request deny")
			require.Contains(t, string(local.Data), "blockedDetails")
		})

		// MCP request pass
		t.Run("mcp request pass", func(t *testing.T) {
			host, status := test.NewTestHost(mcpRequestConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/mcp/call"},
				{":method", "POST"},
			})

			body := `{"method": "tools/call", "params": {"arguments": "safe content"}}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-pass", "Data": {"RiskLevel": "low"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			action := host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})

		// MCP request skip non-tool-call method
		t.Run("mcp request skip non-tool-call", func(t *testing.T) {
			host, status := test.NewTestHost(mcpRequestConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/mcp/call"},
				{":method", "POST"},
			})

			body := `{"method": "resources/list", "params": {}}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionContinue, action)
		})
	})
}

// TestTextModerationPlusResponseDeny checks that a denied response body is
// returned as an OpenAI chat-completion wrapper. The single choice's message
// content is itself a JSON-encoded cfg.DenyResponseBody.
func TestTextModerationPlusResponseDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// TextModerationPlus response deny → exercises text_moderation_plus/text (via common/text) BuildDenyResponseBody response path
		t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})
			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})

			body := `{"choices": [{"message": {"role": "assistant", "content": "bad response"}}]}`
			action := host.CallOnHttpResponseBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-resp-deny", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for response deny")
			require.Contains(t, string(local.Data), "blockedDetails")

			// Verify OpenAI completion shape wrapper
			type openAIChatCompletion struct {
				Choices []struct {
					Message struct {
						Content string `json:"content"`
					} `json:"message"`
				} `json:"choices"`
			}
			var outer openAIChatCompletion
			require.NoError(t, json.Unmarshal(local.Data, &outer))
			require.Len(t, outer.Choices, 1)

			// The message content is itself a JSON-encoded DenyResponseBody.
			var deny cfg.DenyResponseBody
			require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny))
			require.Equal(t, 200, deny.Code)
			require.NotEmpty(t, deny.BlockedDetails)
		})
	})
}

// TestBuildDenyResponseBody calls cfg.BuildDenyResponseBody directly. It checks
// that Code is copied from the response, that Data.Detail entries are filtered
// by the configured level bars (an explicit "block" suggestion is included even
// below its bar), and that fallback entries are synthesized from
// RiskLevel/AttackLevel only when those levels meet the bar.
func TestBuildDenyResponseBody(t *testing.T) {
	// makeConfig builds a MultiModalGuard config with the given content and
	// prompt-attack bars; every other bar is set so it is not met by the
	// fixtures below.
	makeConfig := func(contentBar, promptBar string) cfg.AISecurityConfig {
		return cfg.AISecurityConfig{
			ContentModerationLevelBar:  contentBar,
			PromptAttackLevelBar:       promptBar,
			SensitiveDataLevelBar:      "S4",
			MaliciousUrlLevelBar:       "max",
			ModelHallucinationLevelBar: "max",
			CustomLabelLevelBar:        "max",
			RiskAction:                 "block",
			Action:                     cfg.MultiModalGuard,
		}
	}

	t.Run("code equals response.Code", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-123",
			Data:      cfg.Data{},
		}
		body, err := cfg.BuildDenyResponseBody(resp, makeConfig("high", "high"), "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.Equal(t, 200, result.Code)
	})

	t.Run("blockedDetails from Data.Detail", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-456",
			Data: cfg.Data{
				Detail: []cfg.Detail{
					{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
					{Type: cfg.PromptAttackType, Level: "low", Suggestion: "none"},
				},
			},
		}
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		// only the contentModeration entry meets the "high" bar; promptAttack at "low" does not
		require.Len(t, result.BlockedDetails, 1)
		require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
	})

	t.Run("blockedDetails includes explicit block suggestion below threshold", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-suggestion-block",
			Data: cfg.Data{
				Detail: []cfg.Detail{
					{Type: cfg.SensitiveDataType, Level: "S3", Suggestion: "block"},
				},
			},
		}
		config := makeConfig("high", "high")
		// S3 is below the S4 bar; inclusion must come from Suggestion "block".
		config.SensitiveDataLevelBar = "S4"
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.Len(t, result.BlockedDetails, 1)
		require.Equal(t, cfg.SensitiveDataType, result.BlockedDetails[0].Type)
		require.Equal(t, "S3", result.BlockedDetails[0].Level)
	})

	t.Run("blockedDetails includes customLabel when threshold exceeded", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-custom-label",
			Data: cfg.Data{
				Detail: []cfg.Detail{
					{Type: cfg.CustomLabelType, Level: "high", Suggestion: "none"},
				},
			},
		}
		config := makeConfig("high", "high")
		config.CustomLabelLevelBar = "high"
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.Len(t, result.BlockedDetails, 1)
		require.Equal(t, cfg.CustomLabelType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
	})

	t.Run("blockedDetails fallback from RiskLevel when Detail is empty", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-789",
			Data: cfg.Data{
				RiskLevel: "high",
				// Detail deliberately empty
			},
		}
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from RiskLevel")
		require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
	})

	t.Run("blockedDetails fallback from AttackLevel when Detail is empty", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-abc",
			Data: cfg.Data{
				AttackLevel: "high",
				// Detail deliberately empty
			},
		}
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from AttackLevel")
		require.Equal(t, cfg.PromptAttackType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
	})

	t.Run("blockedDetails empty when risk levels below threshold", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-def",
			Data: cfg.Data{
				RiskLevel:   "low",
				AttackLevel: "low",
			},
		}
		// threshold is "high", so "low" must not produce fallback entries
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.Empty(t, result.BlockedDetails)
	})
}

// TestBuildDenyResponseBody_WithDenyMessage checks that a configured
// DenyMessage is copied verbatim into the deny body.
func TestBuildDenyResponseBody_WithDenyMessage(t *testing.T) {
	config := cfg.AISecurityConfig{
		ContentModerationLevelBar:  "high",
		PromptAttackLevelBar:       "high",
		SensitiveDataLevelBar:      "S4",
		MaliciousUrlLevelBar:       "max",
		ModelHallucinationLevelBar: "max",
		CustomLabelLevelBar:        "max",
		RiskAction:                 "block",
		Action:                     cfg.MultiModalGuard,
		DenyMessage:                "很抱歉,我无法回答您的问题",
	}
	resp := cfg.Response{
		Code: 200,
		Data: cfg.Data{
			Detail: []cfg.Detail{
				{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
			},
		},
	}
	body, err := cfg.BuildDenyResponseBody(resp, config, "")
	require.NoError(t, err)

	var result cfg.DenyResponseBody
	require.NoError(t, json.Unmarshal(body, &result))
	require.Equal(t, "很抱歉,我无法回答您的问题", result.DenyMessage)
}

// TestBuildDenyResponseBody_WithoutDenyMessage checks that the "denyMessage"
// key is absent from the serialized body when no DenyMessage is configured.
func TestBuildDenyResponseBody_WithoutDenyMessage(t *testing.T) {
	config := cfg.AISecurityConfig{
		ContentModerationLevelBar:  "high",
		PromptAttackLevelBar:       "high",
		SensitiveDataLevelBar:      "S4",
		MaliciousUrlLevelBar:       "max",
		ModelHallucinationLevelBar: "max",
		CustomLabelLevelBar:        "max",
		RiskAction:                 "block",
		Action:                     cfg.MultiModalGuard,
	}
	resp := cfg.Response{
		Code: 200,
		Data: cfg.Data{
			Detail: []cfg.Detail{
				{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
			},
		},
	}
	body, err := cfg.BuildDenyResponseBody(resp, config, "")
	require.NoError(t, err)
	require.NotContains(t, string(body), "denyMessage")
}

// TestBuildDenyResponseBody_BlockedDetailsOnlyTypeAndLevel checks that each
// serialized blockedDetails entry carries exactly the "type" and "level" keys.
// The upstream Result labels must not leak into it.
func TestBuildDenyResponseBody_BlockedDetailsOnlyTypeAndLevel(t *testing.T) {
	config := cfg.AISecurityConfig{
		ContentModerationLevelBar:  "high",
		PromptAttackLevelBar:       "high",
		SensitiveDataLevelBar:      "S4",
		MaliciousUrlLevelBar:       "max",
		ModelHallucinationLevelBar: "max",
		CustomLabelLevelBar:        "max",
		RiskAction:                 "block",
		Action:                     cfg.MultiModalGuard,
	}
	resp := cfg.Response{
		Code: 200,
		Data: cfg.Data{
			Detail: []cfg.Detail{
				{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block", Result: []cfg.Result{{Label: "violence"}}},
				{Type: cfg.PromptAttackType, Level: "high", Suggestion: "block", Result: []cfg.Result{{Label: "injection"}}},
			},
		},
	}
	body, err := cfg.BuildDenyResponseBody(resp, config, "")
	require.NoError(t, err)

	var raw map[string]interface{}
	require.NoError(t, json.Unmarshal(body, &raw))
	// Comma-ok assertions: a missing or mistyped field fails the test with a
	// message instead of panicking the test binary.
	details, ok := raw["blockedDetails"].([]interface{})
	require.True(t, ok, "blockedDetails should be a JSON array, got %T", raw["blockedDetails"])
	require.Len(t, details, 2)
	for _, entry := range details {
		m, ok := entry.(map[string]interface{})
		require.True(t, ok, "blockedDetail entry should be a JSON object, got %T", entry)
		require.Len(t, m, 2, "each blockedDetail entry should have exactly 2 keys (type and level)")
		require.Contains(t, m, "type")
		require.Contains(t, m, "level")
	}
}

// TestBuildDenyResponseBody_CodeField checks that Code is copied from the
// security response even when Data is empty.
func TestBuildDenyResponseBody_CodeField(t *testing.T) {
	config := cfg.AISecurityConfig{
		ContentModerationLevelBar:  "high",
		PromptAttackLevelBar:       "high",
		SensitiveDataLevelBar:      "S4",
		MaliciousUrlLevelBar:       "max",
		ModelHallucinationLevelBar: "max",
		CustomLabelLevelBar:        "max",
		RiskAction:                 "block",
		Action:                     cfg.MultiModalGuard,
	}
	resp := cfg.Response{
		Code: 200,
		Data: cfg.Data{},
	}
	body, err := cfg.BuildDenyResponseBody(resp, config, "")
	require.NoError(t, err)

	var result cfg.DenyResponseBody
	require.NoError(t, json.Unmarshal(body, &result))
	require.Equal(t, 200, result.Code)
}

// TestBuildDenyResponseBody_NoRequestId checks that the upstream RequestId is
// not serialized into the deny body.
func TestBuildDenyResponseBody_NoRequestId(t *testing.T) {
	config := cfg.AISecurityConfig{
		ContentModerationLevelBar:  "high",
		PromptAttackLevelBar:       "high",
		SensitiveDataLevelBar:      "S4",
		MaliciousUrlLevelBar:       "max",
		ModelHallucinationLevelBar: "max",
		CustomLabelLevelBar:        "max",
		RiskAction:                 "block",
		Action:                     cfg.MultiModalGuard,
	}
	resp := cfg.Response{
		Code:      200,
		RequestId: "req-should-not-appear",
		Data:      cfg.Data{},
	}
	body, err := cfg.BuildDenyResponseBody(resp, config, "")
	require.NoError(t, err)
	require.NotContains(t, string(body), "requestId")
}

// TestBuildDenyResponseBody_FallbackSynthesis checks that entries synthesized
// from RiskLevel (no Detail present) also carry exactly "type" and "level".
func TestBuildDenyResponseBody_FallbackSynthesis(t *testing.T) {
	config := cfg.AISecurityConfig{
		ContentModerationLevelBar:  "high",
		PromptAttackLevelBar:       "high",
		SensitiveDataLevelBar:      "S4",
		MaliciousUrlLevelBar:       "max",
		ModelHallucinationLevelBar: "max",
		CustomLabelLevelBar:        "max",
		RiskAction:                 "block",
		Action:                     cfg.MultiModalGuard,
	}
	resp := cfg.Response{
		Code: 200,
		Data: cfg.Data{
			RiskLevel: "high",
			// No Detail entries — triggers fallback synthesis
		},
	}
	body, err := cfg.BuildDenyResponseBody(resp, config, "")
	require.NoError(t, err)

	var raw map[string]interface{}
	require.NoError(t, json.Unmarshal(body, &raw))
	// Comma-ok assertions: fail with a message rather than panic on a bad shape.
	details, ok := raw["blockedDetails"].([]interface{})
	require.True(t, ok, "blockedDetails should be a JSON array, got %T", raw["blockedDetails"])
	require.NotEmpty(t, details, "expected fallback synthesized entries")
	for _, entry := range details {
		m, ok := entry.(map[string]interface{})
		require.True(t, ok, "fallback blockedDetail entry should be a JSON object, got %T", entry)
		require.Len(t, m, 2, "fallback blockedDetail entry should have exactly 2 keys (type and level)")
		require.Contains(t, m, "type")
		require.Contains(t, m, "level")
	}
}

// =============================================================================
// TC-COVER: supplementary coverage tests
// =============================================================================

// TestMultiModalGuardStreamDeny covers the stream response-format path in the
// RiskBlock branch of openai.go.
func TestMultiModalGuardStreamDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("stream request deny returns SSE format", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardTextConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			// The request body sets stream=true, which selects the SSE response format.
			body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-deny", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for stream deny")
			require.Contains(t, string(local.Data), "blockedDetails")

			// Verify the SSE content-type (header names are case-insensitive).
			foundSSE := false
			for _, h := range local.Headers {
				if strings.EqualFold(h[0], "content-type") {
					require.Equal(t, "text/event-stream;charset=UTF-8", h[1])
					foundSSE = true
				}
			}
			require.True(t, foundSSE, "expected SSE content-type header")
		})
	})
}

// TestMultiModalGuardProtocolOriginalDeny covers the ProtocolOriginal path in
// the RiskBlock branch of openai.go.
func TestMultiModalGuardProtocolOriginalDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("protocol original deny returns raw blockedDetails JSON", func(t *testing.T) {
			host, status := test.NewTestHost(protocolOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-proto-orig", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for protocol original deny")
			require.Contains(t, string(local.Data), "blockedDetails")

			// ProtocolOriginal returns plain JSON without the OpenAI wrapper.
			// Track whether the header was seen so the check cannot pass vacuously.
			foundContentType := false
			for _, h := range local.Headers {
				if strings.EqualFold(h[0], "content-type") {
					require.Equal(t, "application/json", h[1])
					foundContentType = true
				}
			}
			require.True(t, foundContentType, "expected content-type header")

			// The response body is the raw blockedDetails JSON, not wrapped in OpenAI format.
			require.False(t, gjson.GetBytes(local.Data, "choices").Exists(), "should not wrap in OpenAI format")
		})
	})
}

// TestMultiModalGuardDenyWithAdvice covers the Advice != nil path in the
// RiskBlock branch of openai.go.
func TestMultiModalGuardDenyWithAdvice(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// NOTE(review): despite the subtest name, only the deny response is
		// asserted here; the riskLabel/riskWords attributes are not checked.
		t.Run("deny with advice sets riskLabel and riskWords attributes", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardTextConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "trigger deny with advice"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			// Security service response that includes both Advice and Result.
			securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-advice-deny", "Data": { "RiskLevel": "high", "Result": [{"Label": "porn", "RiskWords": "bad-word"}], "Advice": [{"Answer": "blocked", "HitLabel": "porn"}], "Detail": [{"Suggestion": "block", "Type": "contentModeration", "Level": "high"}] } }`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse")
			require.Contains(t, string(local.Data), "blockedDetails")
		})
	})
}

// TestMultiChunkMasking covers the RiskPass + hasMasked path in openai.go.
// Scenario: the content exceeds LengthLimit (1800). The first chunk triggers
// RiskMask desensitization and the second chunk returns RiskPass.
func TestMultiChunkMasking(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("multi chunk masking with pass on second chunk", func(t *testing.T) {
			host, status := test.NewTestHost(maskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			// Build content longer than LengthLimit (1800).
			longContent := strings.Repeat("a", 2000)
			body := `{"messages": [{"role": "user", "content": "` + longContent + `"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			// First chunk (1800 chars): return a mask suggestion with desensitized content.
			maskedChunk := strings.Repeat("b", 1800)
			securityResponse1 := `{ "Code": 200, "Message": "Success", "RequestId": "req-chunk-1", "Data": { "RiskLevel": "none", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", "Result": [{"Label": "phone", "Confidence": 99.0, "Ext": {"Desensitization": "` + maskedChunk + `"}}] }] } }`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse1))

			// Second chunk (200 chars): return pass (no risk).
			securityResponse2 := `{"Code": 200, "Message": "Success", "RequestId": "req-chunk-2", "Data": {"RiskLevel": "none", "Detail": []}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse2))

			// Verify the request body was rewritten with the desensitized content.
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
			// Expected = desensitized first chunk (1800 'b') + original second chunk (200 'a').
			expectedContent := maskedChunk + strings.Repeat("a", 200)
			require.Equal(t, expectedContent, content)

			host.CompleteHttp()
		})

		// Covers the RiskMask completion path: the content fits in a single chunk
		// (<= LengthLimit), so processing finishes right after RiskMask. This is the
		// contentIndex >= len(maskedContent) sub-path of the RiskMask branch.
		t.Run("single chunk mask completes in RiskMask branch", func(t *testing.T) {
			host, status := test.NewTestHost(maskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "我的银行卡号是6222021234567890"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-single-mask", "Data": { "RiskLevel": "none", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", "Result": [{"Label": "bank_card", "Confidence": 99.0, "Ext": {"Desensitization": "我的银行卡号是6222************"}}] }] } }`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
			require.Equal(t, "我的银行卡号是6222************", content)

			host.CompleteHttp()
		})
	})
}

// TestMultiModalGuardMaskStreamDeny covers the stream path in openai.go where a
// RiskMask verdict with empty desensitized content falls through to RiskBlock.
func TestMultiModalGuardMaskStreamDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("mask with empty desensitization falls through to block stream format", func(t *testing.T) {
			host, status := test.NewTestHost(maskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			// stream=true makes the block response use the SSE format.
			body := `{"messages": [{"role": "user", "content": "敏感内容"}], "stream": true}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

			// Return a mask suggestion with empty desensitized content → falls through to RiskBlock.
			securityResponse := `{ "Code": 200, "Message": "Success", "RequestId": "req-mask-stream-deny", "Data": { "RiskLevel": "none", "Detail": [{ "Suggestion": "mask", "Type": "sensitiveData", "Level": "S3", "Result": [{"Label": "phone", "Confidence": 99.0, "Ext": {"Desensitization": ""}}] }] } }`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse after mask fallthrough to block")

			// Verify the SSE format (header names are case-insensitive).
			foundSSE := false
			for _, h := range local.Headers {
				if strings.EqualFold(h[0], "content-type") {
					require.Equal(t, "text/event-stream;charset=UTF-8", h[1])
					foundSSE = true
				}
			}
			require.True(t, foundSSE, "expected SSE content-type for stream deny")
		})
	})
}