mirror of
https://github.com/alibaba/higress.git
synced 2026-06-04 18:17:33 +08:00
[feat] ai-security-guard refactor & support checking multimoadl input (#3075)
This commit is contained in:
@@ -18,6 +18,8 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, "test-ak", securityConfig.ak)
|
||||
require.Equal(t, "test-sk", securityConfig.sk)
|
||||
require.Equal(t, true, securityConfig.checkRequest)
|
||||
require.Equal(t, true, securityConfig.checkResponse)
|
||||
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
|
||||
require.Equal(t, uint32(2000), securityConfig.timeout)
|
||||
require.Equal(t, 1000, securityConfig.bufferLimit)
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, "test-ak", securityConfig.AK)
|
||||
require.Equal(t, "test-sk", securityConfig.SK)
|
||||
require.Equal(t, true, securityConfig.CheckRequest)
|
||||
require.Equal(t, true, securityConfig.CheckResponse)
|
||||
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||||
require.Equal(t, uint32(2000), securityConfig.Timeout)
|
||||
require.Equal(t, 1000, securityConfig.BufferLimit)
|
||||
})
|
||||
|
||||
// 测试仅检查请求的配置
|
||||
@@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, true, securityConfig.checkRequest)
|
||||
require.Equal(t, false, securityConfig.checkResponse)
|
||||
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, true, securityConfig.CheckRequest)
|
||||
require.Equal(t, false, securityConfig.CheckResponse)
|
||||
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||||
})
|
||||
|
||||
// 测试缺少必需字段的配置
|
||||
@@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
|
||||
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
|
||||
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
|
||||
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
|
||||
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
|
||||
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
|
||||
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
|
||||
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
|
||||
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
|
||||
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
|
||||
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -385,62 +387,27 @@ func TestOnHttpResponseHeaders(t *testing.T) {
|
||||
func TestRiskLevelFunctions(t *testing.T) {
|
||||
// 测试风险等级转换函数
|
||||
t.Run("risk level conversion", func(t *testing.T) {
|
||||
require.Equal(t, 4, levelToInt(MaxRisk))
|
||||
require.Equal(t, 3, levelToInt(HighRisk))
|
||||
require.Equal(t, 2, levelToInt(MediumRisk))
|
||||
require.Equal(t, 1, levelToInt(LowRisk))
|
||||
require.Equal(t, 0, levelToInt(NoRisk))
|
||||
require.Equal(t, -1, levelToInt("invalid"))
|
||||
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
|
||||
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
|
||||
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
|
||||
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
|
||||
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
|
||||
require.Equal(t, -1, cfg.LevelToInt("invalid"))
|
||||
})
|
||||
|
||||
// 测试风险等级比较
|
||||
t.Run("risk level comparison", func(t *testing.T) {
|
||||
require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk))
|
||||
require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk))
|
||||
require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk))
|
||||
require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk))
|
||||
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
|
||||
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
|
||||
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
|
||||
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUtilityFunctions(t *testing.T) {
|
||||
// 测试URL编码函数
|
||||
t.Run("url encoding", func(t *testing.T) {
|
||||
original := "test+string:with=special&chars@$"
|
||||
encoded := urlEncoding(original)
|
||||
require.NotEqual(t, original, encoded)
|
||||
require.Contains(t, encoded, "%2B") // + 应该被编码
|
||||
require.Contains(t, encoded, "%3A") // : 应该被编码
|
||||
require.Contains(t, encoded, "%3D") // = 应该被编码
|
||||
require.Contains(t, encoded, "%26") // & 应该被编码
|
||||
})
|
||||
|
||||
// 测试HMAC-SHA1签名函数
|
||||
t.Run("hmac sha1", func(t *testing.T) {
|
||||
message := "test message"
|
||||
secret := "test secret"
|
||||
signature := hmacSha1(message, secret)
|
||||
require.NotEmpty(t, signature)
|
||||
require.NotEqual(t, message, signature)
|
||||
})
|
||||
|
||||
// 测试签名生成函数
|
||||
t.Run("signature generation", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
params := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
secret := "test-secret"
|
||||
signature := getSign(params, secret)
|
||||
require.NotEmpty(t, signature)
|
||||
})
|
||||
|
||||
// 测试十六进制ID生成函数
|
||||
t.Run("hex id generation", func(t *testing.T) {
|
||||
id, err := generateHexID(16)
|
||||
id, err := utils.GenerateHexID(16)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, id, 16)
|
||||
require.Regexp(t, "^[0-9a-f]+$", id)
|
||||
@@ -448,7 +415,7 @@ func TestUtilityFunctions(t *testing.T) {
|
||||
|
||||
// 测试随机ID生成函数
|
||||
t.Run("random id generation", func(t *testing.T) {
|
||||
id := generateRandomID()
|
||||
id := utils.GenerateRandomChatID()
|
||||
require.NotEmpty(t, id)
|
||||
require.Contains(t, id, "chatcmpl-")
|
||||
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
|
||||
|
||||
Reference in New Issue
Block a user