Files
higress/plugins/wasm-go/extensions/ai-security-guard/main_test.go
2026-01-29 19:25:43 +08:00

652 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// 测试配置:基础安全配置
var basicConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": true,
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
"bufferLimit": 1000,
})
return data
}()
// 测试配置:仅检查请求
var requestOnlyConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": false,
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 1000,
"bufferLimit": 500,
})
return data
}()
// 测试配置:缺少必需字段
var missingRequiredConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"accessKey": "test-ak",
"secretKey": "test-sk",
// 故意缺少必需字段serviceName, servicePort, serviceHost
})
return data
}()
// 测试配置:缺少服务配置字段
var missingServiceConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": true,
// 缺少 serviceName, servicePort, serviceHost
})
return data
}()
// 测试配置:缺少认证字段
var missingAuthConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"checkRequest": true,
"checkResponse": true,
// 缺少 accessKey, secretKey
})
return data
}()
// 测试配置:消费者级别特殊配置
var consumerSpecificConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": false,
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"maliciousUrlLevelBar": "high",
"modelHallucinationLevelBar": "high",
"timeout": 1000,
"bufferLimit": 500,
"consumerRequestCheckService": map[string]interface{}{
"name": "aaa",
"matchType": "exact",
"requestCheckService": "llm_query_moderation_1",
},
"consumerResponseCheckService": map[string]interface{}{
"name": "bbb",
"matchType": "prefix",
"responseCheckService": "llm_response_moderation_1",
},
"consumerRiskLevel": map[string]interface{}{
"name": "ccc.*",
"matchType": "regexp",
"maliciousUrlLevelBar": "low",
},
})
return data
}()
// 测试配置MCP配置
var mcpConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": false,
"checkResponse": true,
"action": "MultiModalGuard",
"apiType": "mcp",
"responseContentJsonPath": "content",
"responseStreamContentJsonPath": "content",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基础配置解析
t.Run("basic config", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.AK)
require.Equal(t, "test-sk", securityConfig.SK)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, true, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.Timeout)
require.Equal(t, 1000, securityConfig.BufferLimit)
})
// 测试仅检查请求的配置
t.Run("request only config", func(t *testing.T) {
host, status := test.NewTestHost(requestOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, false, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
})
// 测试缺少必需字段的配置
t.Run("missing required config", func(t *testing.T) {
host, status := test.NewTestHost(missingRequiredConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试缺少服务配置字段
t.Run("missing service config", func(t *testing.T) {
host, status := test.NewTestHost(missingServiceConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试缺少认证字段
t.Run("missing auth config", func(t *testing.T) {
host, status := test.NewTestHost(missingAuthConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试消费者级别配置
t.Run("consumer specific config", func(t *testing.T) {
host, status := test.NewTestHost(consumerSpecificConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试启用请求检查的情况
t.Run("request checking enabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// 测试禁用请求检查的情况
t.Run("request checking disabled", func(t *testing.T) {
host, status := test.NewTestHost(requestOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试请求体安全检查通过
t.Run("request body security check pass", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置请求体
body := `{"messages": [{"role": "user", "content": "Hello, how are you?"}]}`
action := host.CallOnHttpRequestBody([]byte(body))
// 应该返回ActionPause等待安全检查结果
require.Equal(t, types.ActionPause, action)
// 模拟安全检查服务响应(通过)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
action = host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// 测试空请求内容
t.Run("empty request content", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置空内容的请求体
body := `{"messages": [{"role": "user", "content": ""}]}`
action := host.CallOnHttpRequestBody([]byte(body))
// 空内容应该直接通过
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpResponseHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试启用响应检查的情况
t.Run("response checking enabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 应该返回HeaderStopIteration
require.Equal(t, types.HeaderStopIteration, action)
})
// 测试禁用响应检查的情况
t.Run("response checking disabled", func(t *testing.T) {
host, status := test.NewTestHost(requestOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// 测试非200状态码
t.Run("non-200 status code", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置非200响应头
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "500"},
{"content-type", "application/json"},
})
// 应该返回ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpResponseBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试响应体安全检查通过
t.Run("response body security check pass", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 设置响应体
body := `{"choices": [{"message": {"role": "assistant", "content": "Hello, how can I help you?"}}]}`
action := host.CallOnHttpResponseBody([]byte(body))
// 应该返回ActionPause等待安全检查结果
require.Equal(t, types.ActionPause, action)
// 模拟安全检查服务响应(通过)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
action = host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// 测试空响应内容
t.Run("empty response content", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 设置空内容的响应体
body := `{"choices": [{"message": {"role": "assistant", "content": ""}}]}`
action := host.CallOnHttpResponseBody([]byte(body))
// 空内容应该直接通过
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestMCP(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test MCP Response Body Check - Pass
t.Run("mcp response body security check pass", func(t *testing.T) {
host, status := test.NewTestHost(mcpConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"x-mse-consumer", "test-user"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// body content matching responseContentJsonPath="content"
body := `{"content": "Hello world"}`
action := host.CallOnHttpResponseBody([]byte(body))
require.Equal(t, types.ActionPause, action)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
action = host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test MCP Response Body Check - Deny
t.Run("mcp response body security check deny", func(t *testing.T) {
host, status := test.NewTestHost(mcpConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
body := `{"content": "Bad content"}`
action := host.CallOnHttpResponseBody([]byte(body))
require.Equal(t, types.ActionPause, action)
// High Risk
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
// Verify it was replaced with DenyResponse
// Can't easily verify the replaced body content with current test wrapper but can check action
// Since plugin calls SendHttpResponse, execution stops or changes.
// mcp.go uses SendHttpResponse(..., DenyResponse, -1) which means it ends the stream.
// We can check if GetHttpStreamAction is ActionPause (since it did send a response) or something else.
// Actually SendHttpResponse in proxy-wasm usually terminates further processing of the original stream.
})
// Test MCP Streaming Response Body Check - Pass
t.Run("mcp streaming response body security check pass", func(t *testing.T) {
host, status := test.NewTestHost(mcpConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
// streaming chunk
// config uses "content" key
chunk := []byte(`data: {"content": "Hello"}` + "\n\n")
// This calls OnHttpStreamingResponseBody -> mcp.HandleMcpStreamingResponseBody
// It should push buffer and make call
host.CallOnHttpStreamingResponseBody(chunk, false)
// Action assertion removed as it returns an internal value 3
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
})
// Test MCP Streaming Response Body Check - Deny
t.Run("mcp streaming response body security check deny", func(t *testing.T) {
host, status := test.NewTestHost(mcpConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
chunk := []byte(`data: {"content": "Bad"}` + "\n\n")
host.CallOnHttpStreamingResponseBody(chunk, false)
// High Risk
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
// It injects DenySSEResponse.
})
})
}
func TestRiskLevelFunctions(t *testing.T) {
// 测试风险等级转换函数
t.Run("risk level conversion", func(t *testing.T) {
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
require.Equal(t, -1, cfg.LevelToInt("invalid"))
})
// 测试风险等级比较
t.Run("risk level comparison", func(t *testing.T) {
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
})
}
func TestUtilityFunctions(t *testing.T) {
// 测试十六进制ID生成函数
t.Run("hex id generation", func(t *testing.T) {
id, err := utils.GenerateHexID(16)
require.NoError(t, err)
require.Len(t, id, 16)
require.Regexp(t, "^[0-9a-f]+$", id)
})
// 测试随机ID生成函数
t.Run("random id generation", func(t *testing.T) {
id := utils.GenerateRandomChatID()
require.NotEmpty(t, id)
require.Contains(t, id, "chatcmpl-")
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
})
}