// Copyright (c) 2024 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "encoding/json" "testing" cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/test" "github.com/stretchr/testify/require" ) // 测试配置:基础安全配置 var basicConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, "bufferLimit": 1000, }) return data }() // 测试配置:仅检查请求 var requestOnlyConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 1000, "bufferLimit": 500, }) return data }() // 测试配置:缺少必需字段 var missingRequiredConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "accessKey": "test-ak", "secretKey": "test-sk", // 故意缺少必需字段:serviceName, servicePort, serviceHost }) return data }() // 测试配置:缺少服务配置字段 var missingServiceConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, // 缺少 serviceName, servicePort, serviceHost }) return data }() // 测试配置:缺少认证字段 var missingAuthConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "checkRequest": true, "checkResponse": true, // 缺少 accessKey, secretKey }) return data }() // 测试配置:消费者级别特殊配置 var consumerSpecificConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "maliciousUrlLevelBar": "high", "modelHallucinationLevelBar": "high", "timeout": 1000, "bufferLimit": 500, "consumerRequestCheckService": map[string]interface{}{ "name": "aaa", "matchType": "exact", "requestCheckService": "llm_query_moderation_1", }, "consumerResponseCheckService": map[string]interface{}{ "name": "bbb", "matchType": "prefix", "responseCheckService": "llm_response_moderation_1", }, "consumerRiskLevel": map[string]interface{}{ "name": "ccc.*", "matchType": "regexp", "maliciousUrlLevelBar": "low", }, }) return data }() // 测试配置:MCP配置 var mcpConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": false, "checkResponse": true, "action": "MultiModalGuard", "apiType": "mcp", "responseContentJsonPath": "content", "responseStreamContentJsonPath": "content", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:MCP配置(启用请求检查) var mcpRequestConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": false, "action": "MultiModalGuard", "apiType": "mcp", "requestContentJsonPath": "params.arguments", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:MultiModalGuard 文本生成 var multiModalGuardTextConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "apiType": "text_generation", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, "bufferLimit": 1000, }) return data }() // 测试配置:MultiModalGuard OpenAI 图像生成 var multiModalGuardImageOpenAIConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "apiType": "image_generation", "providerType": "openai", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() // 测试配置:MultiModalGuard Qwen 图像生成 var multiModalGuardImageQwenConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "serviceName": "security-service", "servicePort": 8080, "serviceHost": "security.example.com", "accessKey": "test-ak", "secretKey": "test-sk", "checkRequest": true, "checkResponse": true, "action": "MultiModalGuard", "apiType": "image_generation", "providerType": "qwen", "contentModerationLevelBar": "high", "promptAttackLevelBar": "high", "sensitiveDataLevelBar": "S3", "timeout": 2000, }) return data }() func TestParseConfig(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试基础配置解析 t.Run("basic config", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "test-ak", securityConfig.AK) require.Equal(t, "test-sk", securityConfig.SK) require.Equal(t, true, securityConfig.CheckRequest) require.Equal(t, true, securityConfig.CheckResponse) require.Equal(t, "high", securityConfig.ContentModerationLevelBar) require.Equal(t, "high", securityConfig.PromptAttackLevelBar) require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) require.Equal(t, uint32(2000), securityConfig.Timeout) require.Equal(t, 1000, securityConfig.BufferLimit) }) // 测试仅检查请求的配置 t.Run("request only config", func(t *testing.T) { host, status := test.NewTestHost(requestOnlyConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, true, securityConfig.CheckRequest) require.Equal(t, false, securityConfig.CheckResponse) require.Equal(t, "high", securityConfig.ContentModerationLevelBar) require.Equal(t, "high", securityConfig.PromptAttackLevelBar) require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) }) // 测试缺少必需字段的配置 t.Run("missing required config", func(t *testing.T) { host, status := test.NewTestHost(missingRequiredConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) // 测试缺少服务配置字段 t.Run("missing service config", func(t *testing.T) { host, status := test.NewTestHost(missingServiceConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) // 测试缺少认证字段 t.Run("missing auth config", func(t *testing.T) { host, status := test.NewTestHost(missingAuthConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) // 测试消费者级别配置 t.Run("consumer specific config", func(t *testing.T) { host, status := test.NewTestHost(consumerSpecificConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) config, err := host.GetMatchConfig() require.NoError(t, err) require.NotNil(t, config) securityConfig := config.(*cfg.AISecurityConfig) require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa")) require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa")) require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb")) require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test")) require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc")) require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test")) }) }) } func TestOnHttpRequestHeaders(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试启用请求检查的情况 t.Run("request checking enabled", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 设置请求头 action := host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) // 测试禁用请求检查的情况 t.Run("request checking disabled", func(t *testing.T) { host, status := test.NewTestHost(requestOnlyConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 设置请求头 action := host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) }) } func TestOnHttpRequestBody(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试请求体安全检查通过 t.Run("request body security check pass", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置请求体 body := `{"messages": [{"role": "user", "content": "Hello, how are you?"}]}` action := host.CallOnHttpRequestBody([]byte(body)) // 应该返回ActionPause,等待安全检查结果 require.Equal(t, types.ActionPause, action) // 模拟安全检查服务响应(通过) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 测试空请求内容 t.Run("empty request content", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置空内容的请求体 body := `{"messages": [{"role": "user", "content": ""}]}` action := host.CallOnHttpRequestBody([]byte(body)) // 空内容应该直接通过 require.Equal(t, types.ActionContinue, action) }) // TextModerationPlus(默认 action,含 agent/OpenAI 形态)请求拦截应返回 choices[0].message.content 内的 blockedDetails JSON t.Run("text moderation plus request deny returns blockedDetails in openai completion shape", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for request deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-tmp-deny") type openAIChatCompletion struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` } `json:"choices"` } var outer openAIChatCompletion require.NoError(t, json.Unmarshal(local.Data, &outer)) require.Len(t, outer.Choices, 1) var deny cfg.DenyResponseBody require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny)) require.Equal(t, "req-tmp-deny", deny.RequestId) require.Equal(t, 200, deny.GuardCode) require.NotEmpty(t, deny.BlockedDetails) require.Equal(t, cfg.ContentModerationType, deny.BlockedDetails[0].Type) }) }) } func TestOnHttpResponseHeaders(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试启用响应检查的情况 t.Run("response checking enabled", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 action := host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 应该返回HeaderStopIteration require.Equal(t, types.HeaderStopIteration, action) }) // 测试禁用响应检查的情况 t.Run("response checking disabled", func(t *testing.T) { host, status := test.NewTestHost(requestOnlyConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 action := host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) // 测试非200状态码 t.Run("non-200 status code", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置非200响应头 action := host.CallOnHttpResponseHeaders([][2]string{ {":status", "500"}, {"content-type", "application/json"}, }) // 应该返回ActionContinue require.Equal(t, types.ActionContinue, action) }) }) } func TestOnHttpResponseBody(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试响应体安全检查通过 t.Run("response body security check pass", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 设置响应体 body := `{"choices": [{"message": {"role": "assistant", "content": "Hello, how can I help you?"}}]}` action := host.CallOnHttpResponseBody([]byte(body)) // 应该返回ActionPause,等待安全检查结果 require.Equal(t, types.ActionPause, action) // 模拟安全检查服务响应(通过) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // 测试空响应内容 t.Run("empty response content", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) // 先设置请求头 host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) // 设置响应头 host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // 设置空内容的响应体 body := `{"choices": [{"message": {"role": "assistant", "content": ""}}]}` action := host.CallOnHttpResponseBody([]byte(body)) // 空内容应该直接通过 require.Equal(t, types.ActionContinue, action) }) }) } func TestMCP(t *testing.T) { test.RunTest(t, func(t *testing.T) { // Test MCP Response Body Check - Pass t.Run("mcp response body security check pass", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"x-mse-consumer", "test-user"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) // body content matching responseContentJsonPath="content" body := `{"content": "Hello world"}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action = host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // Test MCP Response Body Check - Deny t.Run("mcp response body security check deny", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) body := `{"content": "Bad content"}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) // High Risk securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // Verify it was replaced with DenyResponse // Can't easily verify the replaced body content with current test wrapper but can check action // Since plugin calls SendHttpResponse, execution stops or changes. // mcp.go uses SendHttpResponse(..., DenyResponse, -1) which means it ends the stream. // We can check if GetHttpStreamAction is ActionPause (since it did send a response) or something else. // Actually SendHttpResponse in proxy-wasm usually terminates further processing of the original stream. }) // Test MCP Streaming Response Body Check - Pass t.Run("mcp streaming response body security check pass", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "text/event-stream"}, }) // streaming chunk // config uses "content" key chunk := []byte(`data: {"content": "Hello"}` + "\n\n") // This calls OnHttpStreamingResponseBody -> mcp.HandleMcpStreamingResponseBody // It should push buffer and make call host.CallOnHttpStreamingResponseBody(chunk, false) // Action assertion removed as it returns an internal value 3 securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) }) // Test MCP Streaming Response Body Check - Deny t.Run("mcp streaming response body security check deny", func(t *testing.T) { host, status := test.NewTestHost(mcpConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "text/event-stream"}, }) chunk := []byte(`data: {"content": "Bad"}` + "\n\n") host.CallOnHttpStreamingResponseBody(chunk, false) // High Risk securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) // It injects DenySSEResponse. }) }) } func TestRiskLevelFunctions(t *testing.T) { // 测试风险等级转换函数 t.Run("risk level conversion", func(t *testing.T) { require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk)) require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk)) require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk)) require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk)) require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk)) require.Equal(t, -1, cfg.LevelToInt("invalid")) }) // 测试风险等级比较 t.Run("risk level comparison", func(t *testing.T) { require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk)) require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk)) require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk)) require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk)) }) } func TestUtilityFunctions(t *testing.T) { // 测试十六进制ID生成函数 t.Run("hex id generation", func(t *testing.T) { id, err := utils.GenerateHexID(16) require.NoError(t, err) require.Len(t, id, 16) require.Regexp(t, "^[0-9a-f]+$", id) }) // 测试随机ID生成函数 t.Run("random id generation", func(t *testing.T) { id := utils.GenerateRandomChatID() require.NotEmpty(t, id) require.Contains(t, id, "chatcmpl-") require.Len(t, id, 38) // "chatcmpl-" + 29 random chars }) } func TestMultiModalGuardTextGenerationDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // MultiModalGuard text_generation request deny → exercises multi_modal_guard/text/openai.go BuildDenyResponseBody path t.Run("multi modal guard text request deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardTextConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "trigger deny"}]}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-text-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for request deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-mmg-text-deny") }) // MultiModalGuard text_generation response deny → exercises common/text/openai.go HandleTextGenerationResponseBody BuildDenyResponseBody path t.Run("multi modal guard text response deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardTextConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) body := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-resp-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for response deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-mmg-resp-deny") }) // MultiModalGuard text_generation request pass t.Run("multi modal guard text request pass", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardTextConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) body := `{"messages": [{"role": "user", "content": "Hello"}]}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-pass", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action := host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) }) } func TestMultiModalGuardImageGenerationDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // OpenAI image generation request deny → exercises multi_modal_guard/image/openai.go BuildDenyResponseBody path t.Run("openai image request deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/images/generations"}, {":method", "POST"}, }) body := `{"prompt": "generate bad image"}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-openai-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for OpenAI image request deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-img-openai-deny") }) // OpenAI image generation request pass t.Run("openai image request pass", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/images/generations"}, {":method", "POST"}, }) body := `{"prompt": "a cute cat"}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-pass", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action := host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // Qwen image generation request deny → exercises multi_modal_guard/image/qwen.go BuildDenyResponseBody path t.Run("qwen image request deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardImageQwenConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/images/generations"}, {":method", "POST"}, }) body := `{"input": {"prompt": "generate bad image"}}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-qwen-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for Qwen image request deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-img-qwen-deny") }) // Qwen image generation request pass t.Run("qwen image request pass", func(t *testing.T) { host, status := test.NewTestHost(multiModalGuardImageQwenConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/images/generations"}, {":method", "POST"}, }) body := `{"input": {"prompt": "a cute cat"}}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-qwen-pass", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action := host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) }) } func TestMCPRequestDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // MCP request deny → exercises multi_modal_guard/mcp/mcp.go HandleMcpRequestBody BuildDenyResponseBody path t.Run("mcp request deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(mcpRequestConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/mcp/call"}, {":method", "POST"}, }) body := `{"method": "tools/call", "params": {"arguments": "bad request content"}}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for MCP request deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-mcp-deny") }) // MCP request pass t.Run("mcp request pass", func(t *testing.T) { host, status := test.NewTestHost(mcpRequestConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/mcp/call"}, {":method", "POST"}, }) body := `{"method": "tools/call", "params": {"arguments": "safe content"}}` require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body))) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-pass", "Data": {"RiskLevel": "low"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) action := host.GetHttpStreamAction() require.Equal(t, types.ActionContinue, action) host.CompleteHttp() }) // MCP request skip non-tool-call method t.Run("mcp request skip non-tool-call", func(t *testing.T) { host, status := test.NewTestHost(mcpRequestConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/mcp/call"}, {":method", "POST"}, }) body := `{"method": "resources/list", "params": {}}` action := host.CallOnHttpRequestBody([]byte(body)) require.Equal(t, types.ActionContinue, action) }) }) } func TestTextModerationPlusResponseDeny(t *testing.T) { test.RunTest(t, func(t *testing.T) { // TextModerationPlus response deny → exercises text_moderation_plus/text (via common/text) BuildDenyResponseBody response path t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) { host, status := test.NewTestHost(basicConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, }) host.CallOnHttpResponseHeaders([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }) body := `{"choices": [{"message": {"role": "assistant", "content": "bad response"}}]}` action := host.CallOnHttpResponseBody([]byte(body)) require.Equal(t, types.ActionPause, action) securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-resp-deny", "Data": {"RiskLevel": "high"}}` host.CallOnHttpCall([][2]string{ {":status", "200"}, {"content-type", "application/json"}, }, []byte(securityResponse)) local := host.GetLocalResponse() require.NotNil(t, local, "expected SendHttpResponse for response deny") require.Contains(t, string(local.Data), "blockedDetails") require.Contains(t, string(local.Data), "req-tmp-resp-deny") // Verify OpenAI completion shape wrapper type openAIChatCompletion struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` } `json:"choices"` } var outer openAIChatCompletion require.NoError(t, json.Unmarshal(local.Data, &outer)) require.Len(t, outer.Choices, 1) var deny cfg.DenyResponseBody require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny)) require.Equal(t, "req-tmp-resp-deny", deny.RequestId) require.Equal(t, 200, deny.GuardCode) require.NotEmpty(t, deny.BlockedDetails) }) }) } func TestBuildDenyResponseBody(t *testing.T) { makeConfig := func(contentBar, promptBar string) cfg.AISecurityConfig { return cfg.AISecurityConfig{ ContentModerationLevelBar: contentBar, PromptAttackLevelBar: promptBar, SensitiveDataLevelBar: "S4", MaliciousUrlLevelBar: "max", ModelHallucinationLevelBar: "max", Action: cfg.MultiModalGuard, } } t.Run("guardCode equals response.Code", func(t *testing.T) { resp := cfg.Response{ Code: 200, RequestId: "req-123", Data: cfg.Data{}, } body, err := cfg.BuildDenyResponseBody(resp, makeConfig("high", "high"), "") require.NoError(t, err) var result cfg.DenyResponseBody require.NoError(t, json.Unmarshal(body, &result)) require.Equal(t, 200, result.GuardCode) require.Equal(t, "req-123", result.RequestId) }) t.Run("blockedDetails from Data.Detail", func(t *testing.T) { resp := cfg.Response{ Code: 200, RequestId: "req-456", Data: cfg.Data{ Detail: []cfg.Detail{ {Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"}, {Type: cfg.PromptAttackType, Level: "low", Suggestion: "block"}, }, }, } config := makeConfig("high", "high") body, err := cfg.BuildDenyResponseBody(resp, config, "") require.NoError(t, err) var result cfg.DenyResponseBody require.NoError(t, json.Unmarshal(body, &result)) // only the contentModeration entry meets the "high" bar; promptAttack at "low" does not require.Len(t, result.BlockedDetails, 1) require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type) require.Equal(t, "high", result.BlockedDetails[0].Level) }) t.Run("blockedDetails fallback from RiskLevel when Detail is empty", func(t *testing.T) { resp := cfg.Response{ Code: 200, RequestId: "req-789", Data: cfg.Data{ RiskLevel: "high", // Detail deliberately empty }, } config := makeConfig("high", "high") body, err := cfg.BuildDenyResponseBody(resp, config, "") require.NoError(t, err) var result cfg.DenyResponseBody require.NoError(t, json.Unmarshal(body, &result)) require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from RiskLevel") require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type) require.Equal(t, "high", result.BlockedDetails[0].Level) require.Equal(t, "block", result.BlockedDetails[0].Suggestion) }) t.Run("blockedDetails fallback from AttackLevel when Detail is empty", func(t *testing.T) { resp := cfg.Response{ Code: 200, RequestId: "req-abc", Data: cfg.Data{ AttackLevel: "high", // Detail deliberately empty }, } config := makeConfig("high", "high") body, err := cfg.BuildDenyResponseBody(resp, config, "") require.NoError(t, err) var result cfg.DenyResponseBody require.NoError(t, json.Unmarshal(body, &result)) require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from AttackLevel") require.Equal(t, cfg.PromptAttackType, result.BlockedDetails[0].Type) require.Equal(t, "high", result.BlockedDetails[0].Level) require.Equal(t, "block", result.BlockedDetails[0].Suggestion) }) t.Run("blockedDetails empty when risk levels below threshold", func(t *testing.T) { resp := cfg.Response{ Code: 200, RequestId: "req-def", Data: cfg.Data{ RiskLevel: "low", AttackLevel: "low", }, } // threshold is "high", so "low" must not produce fallback entries config := makeConfig("high", "high") body, err := cfg.BuildDenyResponseBody(resp, config, "") require.NoError(t, err) var result cfg.DenyResponseBody require.NoError(t, json.Unmarshal(body, &result)) require.Empty(t, result.BlockedDetails) }) }