[feature] add checking of maliciousUrl & modelHallucination, and adjust consumer specific configs (#3024)

This commit is contained in:
rinfx
2025-10-28 14:12:54 +08:00
committed by GitHub
parent 2076ded06f
commit 2a320f87a6
4 changed files with 365 additions and 105 deletions

View File

@@ -96,6 +96,42 @@ var missingAuthConfig = func() json.RawMessage {
return data
}()
// 测试配置:消费者级别特殊配置
var consumerSpecificConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": false,
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"maliciousUrlLevelBar": "high",
"modelHallucinationLevelBar": "high",
"timeout": 1000,
"bufferLimit": 500,
"consumerRequestCheckService": map[string]interface{}{
"name": "aaa",
"matchType": "exact",
"requestCheckService": "llm_query_moderation_1",
},
"consumerResponseCheckService": map[string]interface{}{
"name": "bbb",
"matchType": "prefix",
"responseCheckService": "llm_response_moderation_1",
},
"consumerRiskLevel": map[string]interface{}{
"name": "ccc.*",
"matchType": "regexp",
"maliciousUrlLevelBar": "low",
},
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基础配置解析
@@ -156,6 +192,24 @@ func TestParseConfig(t *testing.T) {
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试消费者级别配置
t.Run("consumer specific config", func(t *testing.T) {
host, status := test.NewTestHost(consumerSpecificConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
@@ -400,25 +454,3 @@ func TestUtilityFunctions(t *testing.T) {
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
})
}
func TestMarshalFunctions(t *testing.T) {
// 测试marshalStr函数
t.Run("marshal string", func(t *testing.T) {
testStr := "Hello, World!"
marshalled := marshalStr(testStr)
require.Equal(t, testStr, marshalled)
})
// 测试extractMessageFromStreamingBody函数
t.Run("extract streaming body", func(t *testing.T) {
// 使用正确的分隔符每个chunk之间用双换行符分隔
streamingData := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"}}]}
{"choices":[{"index":0,"delta":{"role":"assistant","content":" World"}}]}
{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)
extracted := extractMessageFromStreamingBody(streamingData, "choices.0.delta.content")
require.Equal(t, "Hello World", extracted)
})
}