From e70b9ec437fbd031ca755e3e691a71addd754e86 Mon Sep 17 00:00:00 2001 From: rinfx Date: Wed, 17 Sep 2025 16:13:24 +0800 Subject: [PATCH] update ai-security-guard test (#2928) --- .../extensions/ai-security-guard/main_test.go | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 40a8f9cc2..5ce57d9a5 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -26,16 +26,18 @@ import ( // 测试配置:基础安全配置 var basicConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ - "serviceName": "security-service", - "servicePort": 8080, - "serviceHost": "security.example.com", - "accessKey": "test-ak", - "secretKey": "test-sk", - "checkRequest": true, - "checkResponse": true, - "riskLevelBar": "high", - "timeout": 2000, - "bufferLimit": 1000, + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + "bufferLimit": 1000, }) return data }() @@ -43,16 +45,18 @@ var basicConfig = func() json.RawMessage { // 测试配置:仅检查请求 var requestOnlyConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ - "serviceName": "security-service", - "servicePort": 8080, - "serviceHost": "security.example.com", - "accessKey": "test-ak", - "secretKey": "test-sk", - "checkRequest": true, - "checkResponse": false, - "riskLevelBar": "medium", - "timeout": 1000, - "bufferLimit": 500, + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": false, + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 1000, + "bufferLimit": 500, }) return data }() @@ -108,7 +112,9 @@ func TestParseConfig(t *testing.T) { require.Equal(t, "test-sk", securityConfig.sk) require.Equal(t, true, securityConfig.checkRequest) require.Equal(t, true, securityConfig.checkResponse) - require.Equal(t, "high", securityConfig.riskLevelBar) + require.Equal(t, "high", securityConfig.contentModerationLevelBar) + require.Equal(t, "high", securityConfig.promptAttackLevelBar) + require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar) require.Equal(t, uint32(2000), securityConfig.timeout) require.Equal(t, 1000, securityConfig.bufferLimit) }) @@ -125,7 +131,9 @@ func TestParseConfig(t *testing.T) { securityConfig := config.(*AISecurityConfig) require.Equal(t, true, securityConfig.checkRequest) require.Equal(t, false, securityConfig.checkResponse) - require.Equal(t, "medium", securityConfig.riskLevelBar) + require.Equal(t, "high", securityConfig.contentModerationLevelBar) + require.Equal(t, "high", securityConfig.promptAttackLevelBar) + require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar) }) // 测试缺少必需字段的配置 @@ -323,20 +331,20 @@ func TestOnHttpResponseHeaders(t *testing.T) { func TestRiskLevelFunctions(t *testing.T) { // 测试风险等级转换函数 t.Run("risk level conversion", func(t *testing.T) { - require.Equal(t, 4, riskLevelToInt(MaxRisk)) - require.Equal(t, 3, riskLevelToInt(HighRisk)) - require.Equal(t, 2, riskLevelToInt(MediumRisk)) - require.Equal(t, 1, riskLevelToInt(LowRisk)) - require.Equal(t, 0, riskLevelToInt(NoRisk)) - require.Equal(t, -1, riskLevelToInt("invalid")) + require.Equal(t, 4, levelToInt(MaxRisk)) + require.Equal(t, 3, levelToInt(HighRisk)) + require.Equal(t, 2, levelToInt(MediumRisk)) + require.Equal(t, 1, levelToInt(LowRisk)) + require.Equal(t, 0, levelToInt(NoRisk)) + require.Equal(t, -1, levelToInt("invalid")) }) // 测试风险等级比较 t.Run("risk level comparison", func(t *testing.T) { - require.True(t, riskLevelToInt(HighRisk) >= riskLevelToInt(MediumRisk)) - require.True(t, riskLevelToInt(MediumRisk) >= riskLevelToInt(LowRisk)) - require.True(t, riskLevelToInt(LowRisk) >= riskLevelToInt(NoRisk)) - require.False(t, riskLevelToInt(LowRisk) >= riskLevelToInt(HighRisk)) + require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk)) + require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk)) + require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk)) + require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk)) }) }