Files
higress/plugins/wasm-go/extensions/ai-security-guard/main_test.go

1222 lines
42 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// basicConfig is the baseline fixture: it names the security service, supplies
// an access key and secret key, enables both request and response checking,
// and sets the three risk-level bars together with a timeout and buffer limit.
var basicConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
		"bufferLimit":               1000,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init: a silently empty fixture would only show up
		// later as an unrelated-looking test failure.
		panic("marshal basicConfig: " + err.Error())
	}
	return raw
}()
// requestOnlyConfig is a fixture that enables request checking and disables
// response checking; it otherwise mirrors the service, credential and
// risk-level settings used by the other fixtures, with its own timeout and
// buffer limit.
var requestOnlyConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   1000,
		"bufferLimit":               500,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal requestOnlyConfig: " + err.Error())
	}
	return raw
}()
// missingRequiredConfig is a fixture that carries only the credentials.
var missingRequiredConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"accessKey": "test-ak",
		"secretKey": "test-sk",
		// serviceName, servicePort and serviceHost are deliberately left out.
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal missingRequiredConfig: " + err.Error())
	}
	return raw
}()
// missingServiceConfig is a fixture with credentials and both check flags
// enabled, but without any of the service fields.
var missingServiceConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"accessKey":     "test-ak",
		"secretKey":     "test-sk",
		"checkRequest":  true,
		"checkResponse": true,
		// serviceName, servicePort and serviceHost are missing.
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal missingServiceConfig: " + err.Error())
	}
	return raw
}()
// missingAuthConfig is a fixture with the service fields and both check flags
// enabled, but without the access key and secret key.
var missingAuthConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":   "security-service",
		"servicePort":   8080,
		"serviceHost":   "security.example.com",
		"checkRequest":  true,
		"checkResponse": true,
		// accessKey and secretKey are missing.
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal missingAuthConfig: " + err.Error())
	}
	return raw
}()
// consumerSpecificConfig is a fixture with consumer-level overrides: a request
// check service matched by exact name, a response check service matched by
// prefix, and a malicious-URL level bar matched by regular expression. Global
// level bars are all "high".
var consumerSpecificConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":                "security-service",
		"servicePort":                8080,
		"serviceHost":                "security.example.com",
		"accessKey":                  "test-ak",
		"secretKey":                  "test-sk",
		"checkRequest":               true,
		"checkResponse":              false,
		"contentModerationLevelBar":  "high",
		"promptAttackLevelBar":       "high",
		"sensitiveDataLevelBar":      "S3",
		"maliciousUrlLevelBar":       "high",
		"modelHallucinationLevelBar": "high",
		"timeout":                    1000,
		"bufferLimit":                500,
		"consumerRequestCheckService": map[string]interface{}{
			"name":                "aaa",
			"matchType":           "exact",
			"requestCheckService": "llm_query_moderation_1",
		},
		"consumerResponseCheckService": map[string]interface{}{
			"name":                 "bbb",
			"matchType":            "prefix",
			"responseCheckService": "llm_response_moderation_1",
		},
		"consumerRiskLevel": map[string]interface{}{
			"name":                 "ccc.*",
			"matchType":            "regexp",
			"maliciousUrlLevelBar": "low",
		},
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal consumerSpecificConfig: " + err.Error())
	}
	return raw
}()
// mcpConfig is the MCP fixture with response checking only: the action is
// "MultiModalGuard", the API type is "mcp", and both the regular and streaming
// response content paths are "content".
var mcpConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":                   "security-service",
		"servicePort":                   8080,
		"serviceHost":                   "security.example.com",
		"accessKey":                     "test-ak",
		"secretKey":                     "test-sk",
		"checkRequest":                  false,
		"checkResponse":                 true,
		"action":                        "MultiModalGuard",
		"apiType":                       "mcp",
		"responseContentJsonPath":       "content",
		"responseStreamContentJsonPath": "content",
		"contentModerationLevelBar":     "high",
		"promptAttackLevelBar":          "high",
		"sensitiveDataLevelBar":         "S3",
		"timeout":                       2000,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal mcpConfig: " + err.Error())
	}
	return raw
}()
// mcpRequestConfig is the MCP fixture with request checking only: the action
// is "MultiModalGuard", the API type is "mcp", and the request content path is
// "params.arguments".
var mcpRequestConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"action":                    "MultiModalGuard",
		"apiType":                   "mcp",
		"requestContentJsonPath":    "params.arguments",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal mcpRequestConfig: " + err.Error())
	}
	return raw
}()
// multiModalGuardTextConfig is the fixture for the "MultiModalGuard" action
// with API type "text_generation"; request and response checking are both on.
var multiModalGuardTextConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"apiType":                   "text_generation",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
		"bufferLimit":               1000,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal multiModalGuardTextConfig: " + err.Error())
	}
	return raw
}()
// multiModalGuardImageOpenAIConfig is the fixture for the "MultiModalGuard"
// action with API type "image_generation" and provider type "openai".
var multiModalGuardImageOpenAIConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"apiType":                   "image_generation",
		"providerType":              "openai",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal multiModalGuardImageOpenAIConfig: " + err.Error())
	}
	return raw
}()
// multiModalGuardImageQwenConfig is the fixture for the "MultiModalGuard"
// action with API type "image_generation" and provider type "qwen".
var multiModalGuardImageQwenConfig = func() json.RawMessage {
	fields := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"apiType":                   "image_generation",
		"providerType":              "qwen",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	raw, err := json.Marshal(fields)
	if err != nil {
		// Fail loudly at init rather than handing tests an empty fixture.
		panic("marshal multiModalGuardImageQwenConfig: " + err.Error())
	}
	return raw
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基础配置解析
t.Run("basic config", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.AK)
require.Equal(t, "test-sk", securityConfig.SK)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, true, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.Timeout)
require.Equal(t, 1000, securityConfig.BufferLimit)
})
// 测试仅检查请求的配置
t.Run("request only config", func(t *testing.T) {
host, status := test.NewTestHost(requestOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, false, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
})
// 测试缺少必需字段的配置
t.Run("missing required config", func(t *testing.T) {
host, status := test.NewTestHost(missingRequiredConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试缺少服务配置字段
t.Run("missing service config", func(t *testing.T) {
host, status := test.NewTestHost(missingServiceConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试缺少认证字段
t.Run("missing auth config", func(t *testing.T) {
host, status := test.NewTestHost(missingAuthConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试消费者级别配置
t.Run("consumer specific config", func(t *testing.T) {
host, status := test.NewTestHost(consumerSpecificConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试启用请求检查的情况
t.Run("request checking enabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// 测试禁用请求检查的情况
t.Run("request checking disabled", func(t *testing.T) {
host, status := test.NewTestHost(requestOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
// TestOnHttpRequestBody drives the request-body phase with the basic fixture:
// a low-risk verdict resumes the stream, empty content is not sent for
// checking, and a high-risk verdict produces a local deny response.
func TestOnHttpRequestBody(t *testing.T) {
	// Each helper returns a fresh slice so subtests never share header storage.
	chatRequestHeaders := func() [][2]string {
		return [][2]string{
			{":authority", "example.com"},
			{":path", "/v1/chat/completions"},
			{":method", "POST"},
		}
	}
	guardReplyHeaders := func() [][2]string {
		return [][2]string{
			{":status", "200"},
			{"content-type", "application/json"},
		}
	}
	// guardReply renders the security service reply for a request id and risk level.
	guardReply := func(requestID, level string) []byte {
		return []byte(`{"Code": 200, "Message": "Success", "RequestId": "` + requestID + `", "Data": {"RiskLevel": "` + level + `"}}`)
	}

	test.RunTest(t, func(t *testing.T) {
		// Low-risk verdict: the body is paused for the check, then resumed.
		t.Run("request body security check pass", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())

			payload := `{"messages": [{"role": "user", "content": "Hello, how are you?"}]}`
			// The stream pauses while the security check is outstanding.
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(payload)))

			// Simulate the security service answering with a passing verdict.
			host.CallOnHttpCall(guardReplyHeaders(), guardReply("req-123", "low"))

			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})

		// Empty message content passes straight through.
		t.Run("empty request content", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())

			payload := `{"messages": [{"role": "user", "content": ""}]}`
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody([]byte(payload)))
		})

		// With the default action, a denied request is answered with a local
		// response shaped like an OpenAI chat completion whose
		// choices[0].message.content holds the blockedDetails JSON.
		t.Run("text moderation plus request deny returns blockedDetails in openai completion shape", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())

			payload := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(payload)))

			host.CallOnHttpCall(guardReplyHeaders(), guardReply("req-tmp-deny", "high"))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for request deny")
			rendered := string(local.Data)
			require.Contains(t, rendered, "blockedDetails")
			require.Contains(t, rendered, "req-tmp-deny")

			// Decode only the part of the completion envelope the test needs.
			var completion struct {
				Choices []struct {
					Message struct {
						Content string `json:"content"`
					} `json:"message"`
				} `json:"choices"`
			}
			require.NoError(t, json.Unmarshal(local.Data, &completion))
			require.Len(t, completion.Choices, 1)

			var deny cfg.DenyResponseBody
			inner := []byte(completion.Choices[0].Message.Content)
			require.NoError(t, json.Unmarshal(inner, &deny))
			require.Equal(t, "req-tmp-deny", deny.RequestId)
			require.Equal(t, 200, deny.GuardCode)
			require.NotEmpty(t, deny.BlockedDetails)
			require.Equal(t, cfg.ContentModerationType, deny.BlockedDetails[0].Type)
		})
	})
}
func TestOnHttpResponseHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试启用响应检查的情况
t.Run("response checking enabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 应该返回HeaderStopIteration
require.Equal(t, types.HeaderStopIteration, action)
})
// 测试禁用响应检查的情况
t.Run("response checking disabled", func(t *testing.T) {
host, status := test.NewTestHost(requestOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// 测试非200状态码
t.Run("non-200 status code", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置非200响应头
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "500"},
{"content-type", "application/json"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
// TestOnHttpResponseBody drives the response-body phase with the basic
// fixture: a low-risk verdict resumes the stream, and an empty assistant
// message is not sent for checking.
func TestOnHttpResponseBody(t *testing.T) {
	// Each helper returns a fresh slice so subtests never share header storage.
	chatRequestHeaders := func() [][2]string {
		return [][2]string{
			{":authority", "example.com"},
			{":path", "/v1/chat/completions"},
			{":method", "POST"},
		}
	}
	jsonOKHeaders := func() [][2]string {
		return [][2]string{
			{":status", "200"},
			{"content-type", "application/json"},
		}
	}

	test.RunTest(t, func(t *testing.T) {
		// Low-risk verdict: the body is paused for the check, then resumed.
		t.Run("response body security check pass", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())
			host.CallOnHttpResponseHeaders(jsonOKHeaders())

			upstream := `{"choices": [{"message": {"role": "assistant", "content": "Hello, how can I help you?"}}]}`
			// The stream pauses while the security check is outstanding.
			require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(upstream)))

			// Simulate the security service answering with a passing verdict.
			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
			host.CallOnHttpCall(jsonOKHeaders(), []byte(verdict))

			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})

		// Empty assistant content passes straight through.
		t.Run("empty response content", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())
			host.CallOnHttpResponseHeaders(jsonOKHeaders())

			upstream := `{"choices": [{"message": {"role": "assistant", "content": ""}}]}`
			require.Equal(t, types.ActionContinue, host.CallOnHttpResponseBody([]byte(upstream)))
		})
	})
}
// TestMCP exercises response checking with the MCP fixture, for both buffered
// JSON responses and server-sent-event chunks, with low-risk and high-risk
// verdicts from the security service.
func TestMCP(t *testing.T) {
	// requestHeaders returns the common request headers followed by any extras.
	requestHeaders := func(extra ...[2]string) [][2]string {
		headers := [][2]string{
			{":authority", "example.com"},
			{":path", "/v1/chat/completions"},
			{":method", "POST"},
		}
		return append(headers, extra...)
	}
	// okHeaders returns a 200 header set with the given content type.
	okHeaders := func(contentType string) [][2]string {
		return [][2]string{
			{":status", "200"},
			{"content-type", contentType},
		}
	}
	// guardReply renders the security service reply for a risk level.
	guardReply := func(level string) []byte {
		return []byte(`{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "` + level + `"}}`)
	}

	test.RunTest(t, func(t *testing.T) {
		// Buffered response, low-risk verdict: paused for the check, then resumed.
		t.Run("mcp response body security check pass", func(t *testing.T) {
			host, status := test.NewTestHost(mcpConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(requestHeaders([2]string{"x-mse-consumer", "test-user"}))
			host.CallOnHttpResponseHeaders(okHeaders("application/json"))

			// The body carries the text under the "content" key, which is the
			// fixture's responseContentJsonPath.
			upstream := `{"content": "Hello world"}`
			require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(upstream)))

			host.CallOnHttpCall(okHeaders("application/json"), guardReply("low"))

			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})

		// Buffered response, high-risk verdict.
		t.Run("mcp response body security check deny", func(t *testing.T) {
			host, status := test.NewTestHost(mcpConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(requestHeaders())
			host.CallOnHttpResponseHeaders(okHeaders("application/json"))

			upstream := `{"content": "Bad content"}`
			require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(upstream)))

			host.CallOnHttpCall(okHeaders("application/json"), guardReply("high"))

			// Nothing is asserted after the verdict; this subtest only drives
			// the deny path. The original author expected the plugin to answer
			// with a deny response via SendHttpResponse.
			// NOTE(review): other deny tests in this file inspect
			// host.GetLocalResponse(); consider doing the same here once the
			// MCP response deny behavior is confirmed.
		})

		// Streaming response, low-risk verdict.
		t.Run("mcp streaming response body security check pass", func(t *testing.T) {
			host, status := test.NewTestHost(mcpConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(requestHeaders())
			host.CallOnHttpResponseHeaders(okHeaders("text/event-stream"))

			// One SSE event whose JSON uses the "content" key from the fixture.
			event := []byte(`data: {"content": "Hello"}` + "\n\n")
			// The returned action is deliberately not asserted; the original
			// author noted it is an internal value (3).
			host.CallOnHttpStreamingResponseBody(event, false)

			host.CallOnHttpCall(okHeaders("application/json"), guardReply("low"))
		})

		// Streaming response, high-risk verdict.
		t.Run("mcp streaming response body security check deny", func(t *testing.T) {
			host, status := test.NewTestHost(mcpConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(requestHeaders())
			host.CallOnHttpResponseHeaders(okHeaders("text/event-stream"))

			event := []byte(`data: {"content": "Bad"}` + "\n\n")
			host.CallOnHttpStreamingResponseBody(event, false)

			host.CallOnHttpCall(okHeaders("application/json"), guardReply("high"))
			// Nothing is asserted here; the original author expected a deny
			// SSE response to be injected at this point.
		})
	})
}
func TestRiskLevelFunctions(t *testing.T) {
// 测试风险等级转换函数
t.Run("risk level conversion", func(t *testing.T) {
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
require.Equal(t, -1, cfg.LevelToInt("invalid"))
})
// 测试风险等级比较
t.Run("risk level comparison", func(t *testing.T) {
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
})
}
// TestUtilityFunctions checks the shape of the identifiers produced by the
// utils package helpers.
func TestUtilityFunctions(t *testing.T) {
	// A hex id of the requested length made only of lowercase hex digits.
	t.Run("hex id generation", func(t *testing.T) {
		const wantLen = 16

		hexID, err := utils.GenerateHexID(wantLen)
		require.NoError(t, err)
		require.Len(t, hexID, wantLen)
		require.Regexp(t, "^[0-9a-f]+$", hexID)
	})

	// A chat id that contains the "chatcmpl-" marker and is 38 characters long.
	t.Run("random id generation", func(t *testing.T) {
		const (
			marker     = "chatcmpl-"
			randomPart = 29 // characters expected in addition to the marker
		)

		chatID := utils.GenerateRandomChatID()
		require.NotEmpty(t, chatID)
		require.Contains(t, chatID, marker)
		require.Len(t, chatID, len(marker)+randomPart)
	})
}
// TestMultiModalGuardTextGenerationDeny runs the MultiModalGuard
// text_generation fixture through a denied request, a denied response and a
// passing request. Denials must yield a local response that mentions
// "blockedDetails" and the security service's request id.
func TestMultiModalGuardTextGenerationDeny(t *testing.T) {
	// Each helper returns a fresh slice so subtests never share header storage.
	chatRequestHeaders := func() [][2]string {
		return [][2]string{
			{":authority", "example.com"},
			{":path", "/v1/chat/completions"},
			{":method", "POST"},
		}
	}
	jsonOKHeaders := func() [][2]string {
		return [][2]string{
			{":status", "200"},
			{"content-type", "application/json"},
		}
	}
	// guardReply renders the security service reply for a request id and risk level.
	guardReply := func(requestID, level string) []byte {
		return []byte(`{"Code": 200, "Message": "Success", "RequestId": "` + requestID + `", "Data": {"RiskLevel": "` + level + `"}}`)
	}

	test.RunTest(t, func(t *testing.T) {
		// Request phase, high-risk verdict.
		t.Run("multi modal guard text request deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardTextConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())

			payload := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(payload)))

			host.CallOnHttpCall(jsonOKHeaders(), guardReply("req-mmg-text-deny", "high"))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for request deny")
			rendered := string(local.Data)
			require.Contains(t, rendered, "blockedDetails")
			require.Contains(t, rendered, "req-mmg-text-deny")
		})

		// Response phase, high-risk verdict.
		t.Run("multi modal guard text response deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardTextConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())
			host.CallOnHttpResponseHeaders(jsonOKHeaders())

			upstream := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(upstream)))

			host.CallOnHttpCall(jsonOKHeaders(), guardReply("req-mmg-resp-deny", "high"))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for response deny")
			rendered := string(local.Data)
			require.Contains(t, rendered, "blockedDetails")
			require.Contains(t, rendered, "req-mmg-resp-deny")
		})

		// Request phase, low-risk verdict: the stream resumes.
		t.Run("multi modal guard text request pass", func(t *testing.T) {
			host, status := test.NewTestHost(multiModalGuardTextConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(chatRequestHeaders())

			payload := `{"messages": [{"role": "user", "content": "Hello"}]}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(payload)))

			host.CallOnHttpCall(jsonOKHeaders(), guardReply("req-mmg-pass", "low"))

			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})
	})
}
func TestMultiModalGuardImageGenerationDeny(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// OpenAI image generation request deny → exercises multi_modal_guard/image/openai.go BuildDenyResponseBody path
t.Run("openai image request deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})
body := `{"prompt": "generate bad image"}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-openai-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for OpenAI image request deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-img-openai-deny")
})
// OpenAI image generation request pass
t.Run("openai image request pass", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})
body := `{"prompt": "a cute cat"}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
action := host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Qwen image generation request deny → exercises multi_modal_guard/image/qwen.go BuildDenyResponseBody path
t.Run("qwen image request deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})
body := `{"input": {"prompt": "generate bad image"}}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-qwen-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for Qwen image request deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-img-qwen-deny")
})
// Qwen image generation request pass
t.Run("qwen image request pass", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})
body := `{"input": {"prompt": "a cute cat"}}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-qwen-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
action := host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
})
}
// TestMCPRequestDeny runs MCP request bodies through the request-checking MCP
// fixture: a denied tools/call, a passing tools/call, and a method other than
// tools/call, which is expected to continue without a check.
func TestMCPRequestDeny(t *testing.T) {
	// Each helper returns a fresh slice so subtests never share header storage.
	mcpRequestHeaders := func() [][2]string {
		return [][2]string{
			{":authority", "example.com"},
			{":path", "/mcp/call"},
			{":method", "POST"},
		}
	}
	jsonOKHeaders := func() [][2]string {
		return [][2]string{
			{":status", "200"},
			{"content-type", "application/json"},
		}
	}
	// guardReply renders the security service reply for a request id and risk level.
	guardReply := func(requestID, level string) []byte {
		return []byte(`{"Code": 200, "Message": "Success", "RequestId": "` + requestID + `", "Data": {"RiskLevel": "` + level + `"}}`)
	}

	test.RunTest(t, func(t *testing.T) {
		// tools/call with a high-risk verdict: a local deny response is sent.
		t.Run("mcp request deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(mcpRequestConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(mcpRequestHeaders())

			payload := `{"method": "tools/call", "params": {"arguments": "bad request content"}}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(payload)))

			host.CallOnHttpCall(jsonOKHeaders(), guardReply("req-mcp-deny", "high"))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for MCP request deny")
			rendered := string(local.Data)
			require.Contains(t, rendered, "blockedDetails")
			require.Contains(t, rendered, "req-mcp-deny")
		})

		// tools/call with a low-risk verdict: the stream resumes.
		t.Run("mcp request pass", func(t *testing.T) {
			host, status := test.NewTestHost(mcpRequestConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(mcpRequestHeaders())

			payload := `{"method": "tools/call", "params": {"arguments": "safe content"}}`
			require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(payload)))

			host.CallOnHttpCall(jsonOKHeaders(), guardReply("req-mcp-pass", "low"))

			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})

		// A method other than tools/call continues immediately.
		t.Run("mcp request skip non-tool-call", func(t *testing.T) {
			host, status := test.NewTestHost(mcpRequestConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders(mcpRequestHeaders())

			payload := `{"method": "resources/list", "params": {}}`
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody([]byte(payload)))
		})
	})
}
// TestTextModerationPlusResponseDeny feeds an assistant response through a
// host started with basicConfig, answers the guard call with RiskLevel "high",
// and verifies that the resulting local response is an OpenAI-style chat
// completion whose single choice carries a JSON-encoded cfg.DenyResponseBody.
func TestTextModerationPlusResponseDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Exercises the text_moderation_plus/text (via common/text)
		// BuildDenyResponseBody response path.
		t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) {
			host, startStatus := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, startStatus)

			requestHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			}
			host.CallOnHttpRequestHeaders(requestHeaders)

			upstreamHeaders := [][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(upstreamHeaders)

			const upstreamBody = `{"choices": [{"message": {"role": "assistant", "content": "bad response"}}]}`
			bodyAction := host.CallOnHttpResponseBody([]byte(upstreamBody))
			require.Equal(t, types.ActionPause, bodyAction)

			const guardReply = `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-resp-deny", "Data": {"RiskLevel": "high"}}`
			guardReplyHeaders := [][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}
			host.CallOnHttpCall(guardReplyHeaders, []byte(guardReply))

			denied := host.GetLocalResponse()
			require.NotNil(t, denied, "expected SendHttpResponse for response deny")
			payload := string(denied.Data)
			require.Contains(t, payload, "blockedDetails")
			require.Contains(t, payload, "req-tmp-resp-deny")

			// The deny payload must be wrapped in the OpenAI chat
			// completion shape: choices[0].message.content holds the
			// serialized deny body.
			var completion struct {
				Choices []struct {
					Message struct {
						Content string `json:"content"`
					} `json:"message"`
				} `json:"choices"`
			}
			require.NoError(t, json.Unmarshal(denied.Data, &completion))
			require.Len(t, completion.Choices, 1)

			var denyBody cfg.DenyResponseBody
			innerJSON := []byte(completion.Choices[0].Message.Content)
			require.NoError(t, json.Unmarshal(innerJSON, &denyBody))
			require.Equal(t, "req-tmp-resp-deny", denyBody.RequestId)
			require.Equal(t, 200, denyBody.GuardCode)
			require.NotEmpty(t, denyBody.BlockedDetails)
		})
	})
}
// TestBuildDenyResponseBody unit-tests cfg.BuildDenyResponseBody directly,
// without a plugin host. Every case uses a MultiModalGuard config with both
// the content-moderation and prompt-attack bars set to "high" and checks the
// decoded cfg.DenyResponseBody:
//   - GuardCode and RequestId mirror the guard response;
//   - only Data.Detail entries that reach their bar are reported;
//   - when Detail is empty, RiskLevel / AttackLevel at the bar produce a
//     fallback entry with Suggestion "block";
//   - levels below the bar produce no entries.
func TestBuildDenyResponseBody(t *testing.T) {
	// guardConfig returns a MultiModalGuard config with the given content and
	// prompt-attack bars; the remaining bars are fixed ("S4", "max", "max").
	guardConfig := func(contentBar, promptBar string) cfg.AISecurityConfig {
		var c cfg.AISecurityConfig
		c.ContentModerationLevelBar = contentBar
		c.PromptAttackLevelBar = promptBar
		c.SensitiveDataLevelBar = "S4"
		c.MaliciousUrlLevelBar = "max"
		c.ModelHallucinationLevelBar = "max"
		c.Action = cfg.MultiModalGuard
		return c
	}

	// denyBodyFor builds the deny body for resp under the "high"/"high"
	// config and decodes it, failing the test on any build or JSON error.
	denyBodyFor := func(t *testing.T, resp cfg.Response) cfg.DenyResponseBody {
		t.Helper()
		raw, err := cfg.BuildDenyResponseBody(resp, guardConfig("high", "high"), "")
		require.NoError(t, err)
		var decoded cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(raw, &decoded))
		return decoded
	}

	t.Run("guardCode equals response.Code", func(t *testing.T) {
		got := denyBodyFor(t, cfg.Response{
			Code:      200,
			RequestId: "req-123",
			Data:      cfg.Data{},
		})

		require.Equal(t, 200, got.GuardCode)
		require.Equal(t, "req-123", got.RequestId)
	})

	t.Run("blockedDetails from Data.Detail", func(t *testing.T) {
		details := []cfg.Detail{
			{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
			{Type: cfg.PromptAttackType, Level: "low", Suggestion: "block"},
		}
		got := denyBodyFor(t, cfg.Response{
			Code:      200,
			RequestId: "req-456",
			Data:      cfg.Data{Detail: details},
		})

		// Only the contentModeration entry reaches the "high" bar; the
		// promptAttack entry at "low" must be filtered out.
		require.Len(t, got.BlockedDetails, 1)
		only := got.BlockedDetails[0]
		require.Equal(t, cfg.ContentModerationType, only.Type)
		require.Equal(t, "high", only.Level)
	})

	t.Run("blockedDetails fallback from RiskLevel when Detail is empty", func(t *testing.T) {
		// Detail is deliberately left empty so the fallback is used.
		got := denyBodyFor(t, cfg.Response{
			Code:      200,
			RequestId: "req-789",
			Data:      cfg.Data{RiskLevel: "high"},
		})

		require.NotEmpty(t, got.BlockedDetails, "expected fallback detail from RiskLevel")
		first := got.BlockedDetails[0]
		require.Equal(t, cfg.ContentModerationType, first.Type)
		require.Equal(t, "high", first.Level)
		require.Equal(t, "block", first.Suggestion)
	})

	t.Run("blockedDetails fallback from AttackLevel when Detail is empty", func(t *testing.T) {
		// Detail is deliberately left empty so the fallback is used.
		got := denyBodyFor(t, cfg.Response{
			Code:      200,
			RequestId: "req-abc",
			Data:      cfg.Data{AttackLevel: "high"},
		})

		require.NotEmpty(t, got.BlockedDetails, "expected fallback detail from AttackLevel")
		first := got.BlockedDetails[0]
		require.Equal(t, cfg.PromptAttackType, first.Type)
		require.Equal(t, "high", first.Level)
		require.Equal(t, "block", first.Suggestion)
	})

	t.Run("blockedDetails empty when risk levels below threshold", func(t *testing.T) {
		// Both bars are "high", so "low" levels must not create fallback
		// entries.
		got := denyBodyFor(t, cfg.Response{
			Code:      200,
			RequestId: "req-def",
			Data: cfg.Data{
				RiskLevel:   "low",
				AttackLevel: "low",
			},
		})

		require.Empty(t, got.BlockedDetails)
	})
}