[feat] ai-security-guard refactor & support checking multimoadl input (#3075)

This commit is contained in:
rinfx
2025-12-04 16:33:59 +08:00
committed by GitHub
parent 3e24d66079
commit 896bcacf4c
15 changed files with 1932 additions and 1014 deletions

View File

@@ -18,6 +18,8 @@ import (
"encoding/json"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
@@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.ak)
require.Equal(t, "test-sk", securityConfig.sk)
require.Equal(t, true, securityConfig.checkRequest)
require.Equal(t, true, securityConfig.checkResponse)
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.timeout)
require.Equal(t, 1000, securityConfig.bufferLimit)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.AK)
require.Equal(t, "test-sk", securityConfig.SK)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, true, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.Timeout)
require.Equal(t, 1000, securityConfig.BufferLimit)
})
// 测试仅检查请求的配置
@@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, true, securityConfig.checkRequest)
require.Equal(t, false, securityConfig.checkResponse)
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, false, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
})
// 测试缺少必需字段的配置
@@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
@@ -385,62 +387,27 @@ func TestOnHttpResponseHeaders(t *testing.T) {
func TestRiskLevelFunctions(t *testing.T) {
// 测试风险等级转换函数
t.Run("risk level conversion", func(t *testing.T) {
require.Equal(t, 4, levelToInt(MaxRisk))
require.Equal(t, 3, levelToInt(HighRisk))
require.Equal(t, 2, levelToInt(MediumRisk))
require.Equal(t, 1, levelToInt(LowRisk))
require.Equal(t, 0, levelToInt(NoRisk))
require.Equal(t, -1, levelToInt("invalid"))
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
require.Equal(t, -1, cfg.LevelToInt("invalid"))
})
// 测试风险等级比较
t.Run("risk level comparison", func(t *testing.T) {
require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk))
require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk))
require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk))
require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk))
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
})
}
func TestUtilityFunctions(t *testing.T) {
// 测试URL编码函数
t.Run("url encoding", func(t *testing.T) {
original := "test+string:with=special&chars@$"
encoded := urlEncoding(original)
require.NotEqual(t, original, encoded)
require.Contains(t, encoded, "%2B") // + 应该被编码
require.Contains(t, encoded, "%3A") // : 应该被编码
require.Contains(t, encoded, "%3D") // = 应该被编码
require.Contains(t, encoded, "%26") // & 应该被编码
})
// 测试HMAC-SHA1签名函数
t.Run("hmac sha1", func(t *testing.T) {
message := "test message"
secret := "test secret"
signature := hmacSha1(message, secret)
require.NotEmpty(t, signature)
require.NotEqual(t, message, signature)
})
// 测试签名生成函数
t.Run("signature generation", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
params := map[string]string{
"key1": "value1",
"key2": "value2",
}
secret := "test-secret"
signature := getSign(params, secret)
require.NotEmpty(t, signature)
})
// 测试十六进制ID生成函数
t.Run("hex id generation", func(t *testing.T) {
id, err := generateHexID(16)
id, err := utils.GenerateHexID(16)
require.NoError(t, err)
require.Len(t, id, 16)
require.Regexp(t, "^[0-9a-f]+$", id)
@@ -448,7 +415,7 @@ func TestUtilityFunctions(t *testing.T) {
// 测试随机ID生成函数
t.Run("random id generation", func(t *testing.T) {
id := generateRandomID()
id := utils.GenerateRandomChatID()
require.NotEmpty(t, id)
require.Contains(t, id, "chatcmpl-")
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars