mirror of
https://github.com/alibaba/higress.git
synced 2026-06-10 05:07:30 +08:00
4679 lines
169 KiB
Go
4679 lines
169 KiB
Go
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"strings"
|
||
"testing"
|
||
|
||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/proxytest"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/higress-group/wasm-go/pkg/test"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
// basicConfig is the baseline test configuration: request and response
// checks are both enabled, content-moderation and prompt-attack bars are
// "high", the sensitive-data bar is "S3", timeout is 2000 and bufferLimit is
// 1000.
var basicConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
		"bufferLimit":               1000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// requestOnlyConfig enables request checking only (checkResponse=false),
// with timeout 1000 and bufferLimit 500.
var requestOnlyConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   1000,
		"bufferLimit":               500,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// missingRequiredConfig carries only credentials. serviceName, servicePort
// and serviceHost are deliberately omitted so that plugin start-up fails.
var missingRequiredConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"accessKey": "test-ak",
		"secretKey": "test-sk",
	}
	// Marshal cannot fail for a map holding only strings.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// missingServiceConfig has credentials and check flags but no service
// endpoint (serviceName, servicePort, serviceHost are omitted).
var missingServiceConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"accessKey":     "test-ak",
		"secretKey":     "test-sk",
		"checkRequest":  true,
		"checkResponse": true,
	}
	// Marshal cannot fail for a map holding only strings and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// missingAuthConfig has the service endpoint and check flags but no
// credentials (accessKey and secretKey are omitted).
var missingAuthConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":   "security-service",
		"servicePort":   8080,
		"serviceHost":   "security.example.com",
		"checkRequest":  true,
		"checkResponse": true,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// consumerSpecificConfig adds consumer-level overrides on top of a
// request-only base:
//   - exact match "aaa" uses request check service "llm_query_moderation_1";
//   - prefix match "bbb" uses response check service "llm_response_moderation_1";
//   - regexp match "ccc.*" lowers maliciousUrlLevelBar to "low".
//
// Each override is written as a single JSON object rather than a list.
var consumerSpecificConfig = func() json.RawMessage {
	requestOverride := map[string]interface{}{
		"name":                "aaa",
		"matchType":           "exact",
		"requestCheckService": "llm_query_moderation_1",
	}
	responseOverride := map[string]interface{}{
		"name":                 "bbb",
		"matchType":            "prefix",
		"responseCheckService": "llm_response_moderation_1",
	}
	riskOverride := map[string]interface{}{
		"name":                 "ccc.*",
		"matchType":            "regexp",
		"maliciousUrlLevelBar": "low",
	}
	settings := map[string]interface{}{
		"serviceName":                  "security-service",
		"servicePort":                  8080,
		"serviceHost":                  "security.example.com",
		"accessKey":                    "test-ak",
		"secretKey":                    "test-sk",
		"checkRequest":                 true,
		"checkResponse":                false,
		"contentModerationLevelBar":    "high",
		"promptAttackLevelBar":         "high",
		"sensitiveDataLevelBar":        "S3",
		"maliciousUrlLevelBar":         "high",
		"modelHallucinationLevelBar":   "high",
		"timeout":                      1000,
		"bufferLimit":                  500,
		"consumerRequestCheckService":  requestOverride,
		"consumerResponseCheckService": responseOverride,
		"consumerRiskLevel":            riskOverride,
	}
	// Marshal cannot fail for nested maps of strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// customLabelConfig uses MultiModalGuard with a global customLabelLevelBar
// of "high". Two consumerRiskLevel entries override it: exact consumer
// "exact-user" gets "low", and consumers prefixed "prefix-" get "medium".
var customLabelConfig = func() json.RawMessage {
	riskLevels := []map[string]interface{}{
		{
			"name":                "exact-user",
			"matchType":           "exact",
			"customLabelLevelBar": "low",
		},
		{
			"name":                "prefix-",
			"matchType":           "prefix",
			"customLabelLevelBar": "medium",
		},
	}
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"customLabelLevelBar":       "high",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"consumerRiskLevel":         riskLevels,
	}
	// Marshal cannot fail for nested maps/slices of strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// maskConfig uses MultiModalGuard with riskAction "mask" and checks requests
// only (checkResponse=false).
var maskConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"action":                    "MultiModalGuard",
		"riskAction":                "mask",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// mcpConfig uses the "mcp" API type with response checking only. Both
// non-streaming and streaming response content are read from the "content"
// JSON path.
var mcpConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":                   "security-service",
		"servicePort":                   8080,
		"serviceHost":                   "security.example.com",
		"accessKey":                     "test-ak",
		"secretKey":                     "test-sk",
		"checkRequest":                  false,
		"checkResponse":                 true,
		"action":                        "MultiModalGuard",
		"apiType":                       "mcp",
		"responseContentJsonPath":       "content",
		"responseStreamContentJsonPath": "content",
		"contentModerationLevelBar":     "high",
		"promptAttackLevelBar":          "high",
		"sensitiveDataLevelBar":         "S3",
		"timeout":                       2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// mcpRequestConfig uses the "mcp" API type with request checking only,
// reading the request content from the "params.arguments" JSON path.
var mcpRequestConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"action":                    "MultiModalGuard",
		"apiType":                   "mcp",
		"requestContentJsonPath":    "params.arguments",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// multiModalGuardTextConfig uses MultiModalGuard with the "text_generation"
// API type, checks both directions, and sets bufferLimit to 1000.
var multiModalGuardTextConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"apiType":                   "text_generation",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
		"bufferLimit":               1000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// multiModalGuardImageOpenAIConfig uses MultiModalGuard with the
// "image_generation" API type and providerType "openai", checking both
// directions.
var multiModalGuardImageOpenAIConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"apiType":                   "image_generation",
		"providerType":              "openai",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// multiModalGuardImageQwenConfig uses MultiModalGuard with the
// "image_generation" API type and providerType "qwen", checking both
// directions.
var multiModalGuardImageQwenConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             true,
		"action":                    "MultiModalGuard",
		"apiType":                   "image_generation",
		"providerType":              "qwen",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// protocolOriginalConfig uses MultiModalGuard with protocol "original" and
// checks requests only (checkResponse=false).
var protocolOriginalConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"action":                    "MultiModalGuard",
		"protocol":                  "original",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// withConfigOverrides decodes the JSON object in base, sets every key from
// overrides on it (adding or replacing), and returns the re-encoded object.
// base itself is not modified.
//
// Values round-trip through encoding/json, so numbers from base are decoded
// as float64 before being re-encoded. If base is empty, not a JSON object, or
// the literal null, the result contains only the overrides. Previously this
// case panicked by writing into a nil map.
func withConfigOverrides(base json.RawMessage, overrides map[string]interface{}) json.RawMessage {
	var config map[string]interface{}
	// Best-effort decode: fixtures are expected to be valid JSON objects, and
	// anything else falls back to an empty object below.
	if err := json.Unmarshal(base, &config); err != nil {
		config = nil
	}
	if config == nil {
		config = make(map[string]interface{}, len(overrides))
	}
	for k, v := range overrides {
		config[k] = v
	}
	// Marshal only fails for values JSON cannot represent (funcs, channels,
	// NaN, ...); callers pass plain JSON-compatible override values.
	data, _ := json.Marshal(config)
	return data
}
|
||
|
||
func withStructuredFormat(base json.RawMessage) json.RawMessage {
|
||
return withConfigOverrides(base, map[string]interface{}{
|
||
"openAIDenyResponseFormat": string(cfg.OpenAIDenyResponseFormatStructured),
|
||
})
|
||
}
|
||
|
||
func mustDecodeLegacyDenyContent(t *testing.T, content string) cfg.DenyResponseBody {
|
||
t.Helper()
|
||
var denyBody cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal([]byte(content), &denyBody))
|
||
return denyBody
|
||
}
|
||
|
||
// embeddingConfig uses MultiModalGuard with the "embedding" API type and
// checks both directions. Request content is read from "input", response
// content from "data", and the response error message from "error.message".
var embeddingConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":                  "security-service",
		"servicePort":                  8080,
		"serviceHost":                  "security.example.com",
		"accessKey":                    "test-ak",
		"secretKey":                    "test-sk",
		"checkRequest":                 true,
		"checkResponse":                true,
		"action":                       "MultiModalGuard",
		"apiType":                      "embedding",
		"requestContentJsonPath":       "input",
		"responseContentJsonPath":      "data",
		"responseErrorContentJsonPath": "error.message",
		"contentModerationLevelBar":    "high",
		"promptAttackLevelBar":         "high",
		"sensitiveDataLevelBar":        "S3",
		"timeout":                      2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// embeddingRequestOnlyConfig uses the "embedding" API type with request
// checking only, reading the request content from "input".
var embeddingRequestOnlyConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"action":                    "MultiModalGuard",
		"apiType":                   "embedding",
		"requestContentJsonPath":    "input",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
// embeddingResponseOnlyConfig uses the "embedding" API type with response
// checking only. Response content is read from "data" and the error message
// from "error.message".
var embeddingResponseOnlyConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":                  "security-service",
		"servicePort":                  8080,
		"serviceHost":                  "security.example.com",
		"accessKey":                    "test-ak",
		"secretKey":                    "test-sk",
		"checkRequest":                 false,
		"checkResponse":                true,
		"action":                       "MultiModalGuard",
		"apiType":                      "embedding",
		"responseContentJsonPath":      "data",
		"responseErrorContentJsonPath": "error.message",
		"contentModerationLevelBar":    "high",
		"promptAttackLevelBar":         "high",
		"sensitiveDataLevelBar":        "S3",
		"timeout":                      2000,
	}
	// Marshal cannot fail for a map holding only strings, numbers and bools.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
func TestParseConfig(t *testing.T) {
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
// 测试基础配置解析
|
||
t.Run("basic config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, "test-ak", securityConfig.AK)
|
||
require.Equal(t, "test-sk", securityConfig.SK)
|
||
require.Equal(t, true, securityConfig.CheckRequest)
|
||
require.Equal(t, true, securityConfig.CheckResponse)
|
||
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||
require.Equal(t, uint32(2000), securityConfig.Timeout)
|
||
require.Equal(t, 1000, securityConfig.BufferLimit)
|
||
require.Equal(t, cfg.DefaultResponseFallbackJsonPaths(), securityConfig.ResponseContentFallbackJsonPaths)
|
||
require.Equal(t, cfg.DefaultStreamingResponseFallbackJsonPaths(), securityConfig.ResponseStreamContentFallbackJsonPaths)
|
||
require.Equal(t, cfg.OpenAIDenyResponseFormatLegacy, securityConfig.OpenAIDenyResponseFormat)
|
||
})
|
||
|
||
t.Run("openai deny response format explicit legacy", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{
|
||
"openAIDenyResponseFormat": "legacy",
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, cfg.OpenAIDenyResponseFormatLegacy, securityConfig.OpenAIDenyResponseFormat)
|
||
})
|
||
|
||
t.Run("openai deny response format explicit structured", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(requestOnlyConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, cfg.OpenAIDenyResponseFormatStructured, securityConfig.OpenAIDenyResponseFormat)
|
||
})
|
||
|
||
t.Run("invalid openai deny response format", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{
|
||
"openAIDenyResponseFormat": "json",
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
t.Run("empty openai deny response format is invalid", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withConfigOverrides(requestOnlyConfig, map[string]interface{}{
|
||
"openAIDenyResponseFormat": "",
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
t.Run("consumer risk level cannot override openai deny response format", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"action": "MultiModalGuard",
|
||
"contentModerationLevelBar": "high",
|
||
"consumerRiskLevel": []map[string]interface{}{
|
||
{
|
||
"name": "consumer-a",
|
||
"matchType": "exact",
|
||
"openAIDenyResponseFormat": "structured",
|
||
},
|
||
},
|
||
})
|
||
require.NoError(t, err)
|
||
var securityConfig cfg.AISecurityConfig
|
||
parseErr := securityConfig.Parse(gjson.ParseBytes(configJSON))
|
||
require.EqualError(t, parseErr, cfg.OpenAIDenyResponseFormatConsumerScopeError)
|
||
})
|
||
|
||
// 测试仅检查请求的配置
|
||
t.Run("request only config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(requestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, true, securityConfig.CheckRequest)
|
||
require.Equal(t, false, securityConfig.CheckResponse)
|
||
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||
})
|
||
|
||
// 测试缺少必需字段的配置
|
||
t.Run("missing required config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(missingRequiredConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
// 测试缺少服务配置字段
|
||
t.Run("missing service config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(missingServiceConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
// 测试缺少认证字段
|
||
t.Run("missing auth config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(missingAuthConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
// 测试消费者级别配置
|
||
t.Run("consumer specific config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(consumerSpecificConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
|
||
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
|
||
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
|
||
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
|
||
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
|
||
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
|
||
})
|
||
|
||
t.Run("custom response fallback paths config", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"responseContentFallbackJsonPaths": []string{"output.text", "choices.0.message.content"},
|
||
"responseStreamContentFallbackJsonPaths": []string{"payload.delta", "delta.text"},
|
||
})
|
||
require.NoError(t, err)
|
||
host, status := test.NewTestHost(configJSON)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, []string{"output.text", "choices.0.message.content"}, securityConfig.ResponseContentFallbackJsonPaths)
|
||
require.Equal(t, []string{"payload.delta", "delta.text"}, securityConfig.ResponseStreamContentFallbackJsonPaths)
|
||
})
|
||
|
||
t.Run("empty response fallback paths disable fallback", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"responseContentFallbackJsonPaths": []string{},
|
||
"responseStreamContentFallbackJsonPaths": []string{},
|
||
})
|
||
require.NoError(t, err)
|
||
host, status := test.NewTestHost(configJSON)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Len(t, securityConfig.ResponseContentFallbackJsonPaths, 0)
|
||
require.Len(t, securityConfig.ResponseStreamContentFallbackJsonPaths, 0)
|
||
})
|
||
|
||
t.Run("invalid response fallback paths type", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"responseContentFallbackJsonPaths": "choices.0.message.content",
|
||
})
|
||
require.NoError(t, err)
|
||
host, status := test.NewTestHost(configJSON)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
t.Run("invalid response fallback paths item", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"responseStreamContentFallbackJsonPaths": []interface{}{"delta.text", ""},
|
||
})
|
||
require.NoError(t, err)
|
||
host, status := test.NewTestHost(configJSON)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
t.Run("invalid response fallback paths non-string item", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"responseStreamContentFallbackJsonPaths": []interface{}{"delta.text", 123},
|
||
})
|
||
require.NoError(t, err)
|
||
host, status := test.NewTestHost(configJSON)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
t.Run("invalid contentModerationLevelBar value", func(t *testing.T) {
|
||
configJSON, err := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"contentModerationLevelBar": "invalid",
|
||
})
|
||
require.NoError(t, err)
|
||
host, status := test.NewTestHost(configJSON)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试启用请求检查的情况
|
||
t.Run("request checking enabled", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 应该返回ActionContinue
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
|
||
// 测试禁用请求检查的情况
|
||
t.Run("request checking disabled", func(t *testing.T) {
|
||
host, status := test.NewTestHost(requestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 应该返回ActionContinue
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOnHttpRequestBody(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试请求体安全检查通过
|
||
t.Run("request body security check pass", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置请求体
|
||
body := `{"messages": [{"role": "user", "content": "Hello, how are you?"}]}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
|
||
// 应该返回ActionPause,等待安全检查结果
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟安全检查服务响应(通过)
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试空请求内容
|
||
t.Run("empty request content", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置空内容的请求体
|
||
body := `{"messages": [{"role": "user", "content": ""}]}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
|
||
// 空内容应该直接通过
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOnHttpResponseHeaders(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试启用响应检查的情况
|
||
t.Run("response checking enabled", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置响应头
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// 应该返回HeaderStopIteration
|
||
require.Equal(t, types.HeaderStopIteration, action)
|
||
})
|
||
|
||
// 测试禁用响应检查的情况
|
||
t.Run("response checking disabled", func(t *testing.T) {
|
||
host, status := test.NewTestHost(requestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置响应头
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// 应该返回ActionContinue
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
|
||
// 测试非200状态码
|
||
t.Run("non-200 status code", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置非200响应头
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "500"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// 应该返回ActionContinue
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOnHttpResponseBody(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试响应体安全检查通过
|
||
t.Run("response body security check pass", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置响应头
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// 设置响应体
|
||
body := `{"choices": [{"message": {"role": "assistant", "content": "Hello, how can I help you?"}}]}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
|
||
// 应该返回ActionPause,等待安全检查结果
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟安全检查服务响应(通过)
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试空响应内容
|
||
t.Run("empty response content", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 先设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 设置响应头
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// 设置空内容的响应体
|
||
body := `{"choices": [{"message": {"role": "assistant", "content": ""}}]}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
|
||
// 空内容应该直接通过
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestResponseFallbackExtractionCoverage(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
base := map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkResponse": true,
|
||
"action": "MultiModalGuard",
|
||
"apiType": "text_generation",
|
||
"contentModerationLevelBar": "high",
|
||
"promptAttackLevelBar": "high",
|
||
"sensitiveDataLevelBar": "S3",
|
||
"timeout": 2000,
|
||
"bufferLimit": 1000,
|
||
}
|
||
|
||
withOverrides := func(overrides map[string]interface{}) json.RawMessage {
|
||
cfgMap := make(map[string]interface{}, len(base)+len(overrides))
|
||
for k, v := range base {
|
||
cfgMap[k] = v
|
||
}
|
||
for k, v := range overrides {
|
||
cfgMap[k] = v
|
||
}
|
||
data, err := json.Marshal(cfgMap)
|
||
require.NoError(t, err)
|
||
return data
|
||
}
|
||
|
||
t.Run("streaming response chunk uses configured fallback path", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withOverrides(map[string]interface{}{
|
||
"responseStreamContentJsonPath": "nonexistent.path",
|
||
"responseStreamContentFallbackJsonPaths": []string{"choices.0.delta.content"},
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "text/event-stream"},
|
||
})
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
chunk := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"hello fallback\"}}]}\n\n")
|
||
host.CallOnHttpStreamingResponseBody(chunk, true)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-fallback", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
t.Run("buffered response body uses streaming fallback extraction", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withOverrides(map[string]interface{}{
|
||
"responseStreamContentJsonPath": "nonexistent.path",
|
||
"responseStreamContentFallbackJsonPaths": []string{"choices.0.delta.content"},
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "text/event-stream"},
|
||
})
|
||
|
||
body := "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\ndata: [DONE]\n\n"
|
||
host.CallOnHttpResponseBody([]byte(body))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-buffered-stream-fallback", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
host.CompleteHttp()
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestMCP(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// Test MCP Response Body Check - Pass
|
||
t.Run("mcp response body security check pass", func(t *testing.T) {
|
||
host, status := test.NewTestHost(mcpConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
{"x-mse-consumer", "test-user"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// body content matching responseContentJsonPath="content"
|
||
body := `{"content": "Hello world"}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// Test MCP Response Body Check - Deny
|
||
t.Run("mcp response body security check deny", func(t *testing.T) {
|
||
host, status := test.NewTestHost(mcpConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
body := `{"content": "Bad content"}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// High Risk
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// Verify it was replaced with DenyResponse
|
||
// Can't easily verify the replaced body content with current test wrapper but can check action
|
||
// Since plugin calls SendHttpResponse, execution stops or changes.
|
||
// mcp.go uses SendHttpResponse(..., DenyResponse, -1) which means it ends the stream.
|
||
// We can check if GetHttpStreamAction is ActionPause (since it did send a response) or something else.
|
||
// Actually SendHttpResponse in proxy-wasm usually terminates further processing of the original stream.
|
||
})
|
||
|
||
// Test MCP Streaming Response Body Check - Pass
|
||
t.Run("mcp streaming response body security check pass", func(t *testing.T) {
|
||
host, status := test.NewTestHost(mcpConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "text/event-stream"},
|
||
})
|
||
|
||
// streaming chunk
|
||
// config uses "content" key
|
||
chunk := []byte(`data: {"content": "Hello"}` + "\n\n")
|
||
// This calls OnHttpStreamingResponseBody -> mcp.HandleMcpStreamingResponseBody
|
||
// It should push buffer and make call
|
||
host.CallOnHttpStreamingResponseBody(chunk, false)
|
||
// Action assertion removed as it returns an internal value 3
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
})
|
||
|
||
// Test MCP Streaming Response Body Check - Deny
|
||
t.Run("mcp streaming response body security check deny", func(t *testing.T) {
|
||
host, status := test.NewTestHost(mcpConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "text/event-stream"},
|
||
})
|
||
|
||
chunk := []byte(`data: {"content": "Bad"}` + "\n\n")
|
||
host.CallOnHttpStreamingResponseBody(chunk, false)
|
||
|
||
// High Risk
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// It injects DenySSEResponse.
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestGetRiskAction(t *testing.T) {
|
||
// 测试全局默认值
|
||
t.Run("default is block", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
require.Equal(t, "block", config.GetRiskAction("any-consumer"))
|
||
})
|
||
|
||
// 测试全局配置为 mask,无消费者覆盖
|
||
t.Run("global mask without consumer override", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{RiskAction: "mask"}
|
||
require.Equal(t, "mask", config.GetRiskAction("any-consumer"))
|
||
})
|
||
|
||
// 测试消费者级别覆盖 riskAction
|
||
t.Run("consumer overrides riskAction to mask", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "vip-user"},
|
||
"riskAction": "mask",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "mask", config.GetRiskAction("vip-user"))
|
||
require.Equal(t, "block", config.GetRiskAction("normal-user"))
|
||
})
|
||
|
||
// 测试消费者匹配但未配置 riskAction,fallback 到全局
|
||
t.Run("consumer matched without riskAction falls back to global", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "some-user"},
|
||
"contentModerationLevelBar": "low",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "mask", config.GetRiskAction("some-user"))
|
||
})
|
||
|
||
// 测试 prefix 匹配
|
||
t.Run("consumer prefix match", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Prefix: "test-"},
|
||
"riskAction": "mask",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "mask", config.GetRiskAction("test-user-1"))
|
||
require.Equal(t, "block", config.GetRiskAction("prod-user"))
|
||
})
|
||
|
||
// 测试空 consumer 不匹配任何规则
|
||
t.Run("empty consumer falls back to global", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "vip"},
|
||
"riskAction": "block",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "mask", config.GetRiskAction(""))
|
||
})
|
||
}
|
||
|
||
func TestEvaluateRiskWithConsumerRiskAction(t *testing.T) {
|
||
// 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof
|
||
opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{})
|
||
_, reset := proxytest.NewHostEmulator(opt)
|
||
defer reset()
|
||
|
||
// 测试全局 block,消费者 mask
|
||
t.Run("global block consumer mask", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "vip-user"},
|
||
"riskAction": "mask",
|
||
"sensitiveDataLevelBar": "S2",
|
||
},
|
||
},
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Suggestion: "mask", Type: "sensitiveData", Level: "S2",
|
||
Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}},
|
||
},
|
||
}
|
||
// vip-user 使用 mask 模式,consumer 阈值 S2,Level=S2 >= S2 → RiskMask
|
||
require.Equal(t, cfg.RiskMask, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "vip-user"))
|
||
// normal-user 使用全局 block 模式,全局阈值 S4,Level=S2 < S4 → RiskPass
|
||
require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "normal-user"))
|
||
})
|
||
|
||
// 测试全局 mask,消费者 block
|
||
t.Run("global mask consumer block", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S2",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "strict-user"},
|
||
"riskAction": "block",
|
||
},
|
||
},
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Suggestion: "mask", Type: "sensitiveData", Level: "S2",
|
||
Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}},
|
||
},
|
||
}
|
||
// strict-user 使用 block 模式,Level=S2 >= S2 但 Suggestion=mask + dimAction=block → detailTriggersBlock 返回 false(mask suggestion 不触发 block)
|
||
// 实际上 detailTriggersBlock: Suggestion != "block", dimAction == "block" → return exceeds
|
||
// exceeds = S2 >= S2 = true → RiskBlock
|
||
// 所以 strict-user 应该是 RiskBlock
|
||
require.Equal(t, cfg.RiskBlock, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "strict-user"))
|
||
// other-user 使用全局 mask 模式,Level=S2 >= S2 → RiskMask
|
||
require.Equal(t, cfg.RiskMask, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, "other-user"))
|
||
})
|
||
}
|
||
|
||
func TestParseConsumerRiskActionValidation(t *testing.T) {
|
||
// 测试消费者级别 riskAction 无效值
|
||
t.Run("invalid consumer riskAction", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
configJSON := `{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"consumerRiskLevel": [
|
||
{"name": "user1", "matchType": "exact", "riskAction": "invalid"}
|
||
]
|
||
}`
|
||
err := config.Parse(gjson.Parse(configJSON))
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "invalid riskAction in consumerRiskLevel")
|
||
})
|
||
|
||
// 测试消费者级别 riskAction 有效值
|
||
t.Run("valid consumer riskAction", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
configJSON := `{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"consumerRiskLevel": [
|
||
{"name": "user1", "matchType": "exact", "riskAction": "mask"}
|
||
]
|
||
}`
|
||
err := config.Parse(gjson.Parse(configJSON))
|
||
require.NoError(t, err)
|
||
require.Equal(t, "mask", config.GetRiskAction("user1"))
|
||
require.Equal(t, "block", config.GetRiskAction("other"))
|
||
})
|
||
}
|
||
|
||
func TestRiskLevelFunctions(t *testing.T) {
|
||
// 测试风险等级转换函数
|
||
t.Run("risk level conversion", func(t *testing.T) {
|
||
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
|
||
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
|
||
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
|
||
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
|
||
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
|
||
require.Equal(t, -1, cfg.LevelToInt("invalid"))
|
||
})
|
||
|
||
// 测试风险等级比较
|
||
t.Run("risk level comparison", func(t *testing.T) {
|
||
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
|
||
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
|
||
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
|
||
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
|
||
})
|
||
}
|
||
|
||
func TestUtilityFunctions(t *testing.T) {
|
||
// 测试十六进制ID生成函数
|
||
t.Run("hex id generation", func(t *testing.T) {
|
||
id, err := utils.GenerateHexID(16)
|
||
require.NoError(t, err)
|
||
require.Len(t, id, 16)
|
||
require.Regexp(t, "^[0-9a-f]+$", id)
|
||
})
|
||
|
||
// 测试随机ID生成函数
|
||
t.Run("random id generation", func(t *testing.T) {
|
||
id := utils.GenerateRandomChatID()
|
||
require.NotEmpty(t, id)
|
||
require.Contains(t, id, "chatcmpl-")
|
||
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
|
||
})
|
||
}
|
||
|
||
func TestCustomLabelConstant(t *testing.T) {
|
||
// 验证 CustomLabelType 常量值
|
||
require.Equal(t, "customLabel", cfg.CustomLabelType)
|
||
}
|
||
|
||
func TestCustomLabelConfigParsing(t *testing.T) {
|
||
// 测试 customLabelLevelBar 设置为 high
|
||
t.Run("customLabelLevelBar set to high", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
configJSON := `{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"customLabelLevelBar": "high"
|
||
}`
|
||
err := config.Parse(gjson.Parse(configJSON))
|
||
require.NoError(t, err)
|
||
require.Equal(t, "high", config.CustomLabelLevelBar)
|
||
})
|
||
|
||
// 测试 customLabelLevelBar 设置为 max
|
||
t.Run("customLabelLevelBar set to max", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
configJSON := `{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"customLabelLevelBar": "max"
|
||
}`
|
||
err := config.Parse(gjson.Parse(configJSON))
|
||
require.NoError(t, err)
|
||
require.Equal(t, "max", config.CustomLabelLevelBar)
|
||
})
|
||
|
||
// 测试 customLabelLevelBar 缺省时默认为 max
|
||
t.Run("customLabelLevelBar defaults to max", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
configJSON := `{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk"
|
||
}`
|
||
err := config.Parse(gjson.Parse(configJSON))
|
||
require.NoError(t, err)
|
||
require.Equal(t, "max", config.CustomLabelLevelBar)
|
||
})
|
||
|
||
// 测试 customLabelLevelBar 无效值
|
||
t.Run("customLabelLevelBar invalid value", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{}
|
||
config.SetDefaultValues()
|
||
configJSON := `{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"customLabelLevelBar": "invalid"
|
||
}`
|
||
err := config.Parse(gjson.Parse(configJSON))
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "invalid customLabelLevelBar")
|
||
})
|
||
}
|
||
|
||
func TestGetCustomLabelLevelBar(t *testing.T) {
|
||
// 测试消费者精确匹配
|
||
t.Run("consumer exact match", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
CustomLabelLevelBar: "max",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "exact-user"},
|
||
"customLabelLevelBar": "low",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "low", config.GetCustomLabelLevelBar("exact-user"))
|
||
})
|
||
|
||
// 测试消费者前缀匹配
|
||
t.Run("consumer prefix match", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
CustomLabelLevelBar: "max",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Prefix: "prefix-"},
|
||
"customLabelLevelBar": "medium",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "medium", config.GetCustomLabelLevelBar("prefix-user"))
|
||
})
|
||
|
||
// 测试无匹配回退全局值
|
||
t.Run("no match falls back to global", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
CustomLabelLevelBar: "high",
|
||
ConsumerRiskLevel: []map[string]interface{}{
|
||
{
|
||
"matcher": cfg.Matcher{Exact: "other-user"},
|
||
"customLabelLevelBar": "low",
|
||
},
|
||
},
|
||
}
|
||
require.Equal(t, "high", config.GetCustomLabelLevelBar("unmatched-user"))
|
||
})
|
||
}
|
||
|
||
func TestCustomLabelDetailExceedsThreshold(t *testing.T) {
|
||
// 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof
|
||
opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{})
|
||
_, reset := proxytest.NewHostEmulator(opt)
|
||
defer reset()
|
||
|
||
// 测试 customLabel Level=high, threshold=high → 拦截 (true)
|
||
t.Run("level high threshold high blocks", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "high",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.CustomLabelType, Level: "high"},
|
||
},
|
||
}
|
||
require.Equal(t, cfg.RiskBlock, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 测试 customLabel Level=none, threshold=high → 不拦截 (false)
|
||
t.Run("level none threshold high passes", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "high",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.CustomLabelType, Level: "none"},
|
||
},
|
||
}
|
||
require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 测试 customLabel Level=high, threshold=max → 不拦截 (false)
|
||
t.Run("level high threshold max passes", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.CustomLabelType, Level: "high"},
|
||
},
|
||
}
|
||
require.Equal(t, cfg.RiskPass, cfg.EvaluateRisk(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
}
|
||
|
||
func TestCustomLabelConfigIntegration(t *testing.T) {
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
// 测试 customLabelConfig 配置解析和消费者覆盖
|
||
t.Run("customLabel config with consumer override", func(t *testing.T) {
|
||
host, status := test.NewTestHost(customLabelConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, "high", securityConfig.CustomLabelLevelBar)
|
||
require.Equal(t, "low", securityConfig.GetCustomLabelLevelBar("exact-user"))
|
||
require.Equal(t, "medium", securityConfig.GetCustomLabelLevelBar("prefix-user"))
|
||
require.Equal(t, "high", securityConfig.GetCustomLabelLevelBar("unknown-user"))
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestRequestMasking(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试请求阶段脱敏成功:riskAction=mask,API 返回 mask 建议,请求体被替换为脱敏内容
|
||
t.Run("request masking success", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// API 返回 mask 建议,包含脱敏内容
|
||
securityResponse := `{
|
||
"Code": 200,
|
||
"Message": "Success",
|
||
"RequestId": "req-123",
|
||
"Data": {
|
||
"RiskLevel": "low",
|
||
"Detail": [{
|
||
"Suggestion": "mask",
|
||
"Type": "sensitiveData",
|
||
"Level": "S3",
|
||
"Result": [{
|
||
"Label": "phone_number",
|
||
"Confidence": 99.0,
|
||
"Ext": {
|
||
"Desensitization": "我的电话是1**********",
|
||
"SensitiveData": ["13800138000"]
|
||
}
|
||
}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// 验证请求体被替换为脱敏内容
|
||
processedBody := host.GetRequestBody()
|
||
require.NotNil(t, processedBody)
|
||
content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
|
||
require.Equal(t, "我的电话是1**********", content)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试脱敏内容为空时回退到拦截
|
||
t.Run("empty desensitization falls back to block", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// API 返回 mask 建议,但 Desensitization 为空
|
||
securityResponse := `{
|
||
"Code": 200,
|
||
"Message": "Success",
|
||
"RequestId": "req-123",
|
||
"Data": {
|
||
"RiskLevel": "low",
|
||
"Detail": [{
|
||
"Suggestion": "mask",
|
||
"Type": "sensitiveData",
|
||
"Level": "S3",
|
||
"Result": [{
|
||
"Label": "phone_number",
|
||
"Confidence": 99.0,
|
||
"Ext": {
|
||
"Desensitization": "",
|
||
"SensitiveData": ["13800138000"]
|
||
}
|
||
}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// Desensitization 为空时应回退到拦截,SendHttpResponse 被调用
|
||
// 验证请求体未被修改(原始内容保持不变)
|
||
processedBody := host.GetRequestBody()
|
||
require.NotNil(t, processedBody)
|
||
content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
|
||
require.Equal(t, "我的电话是13800138000", content)
|
||
})
|
||
|
||
// 测试 riskAction=block 时 mask 建议按现有逻辑处理(向后兼容)
|
||
t.Run("riskAction block keeps existing behavior for mask suggestion", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// API 返回 mask 建议,但全局 riskAction 默认为 block
|
||
securityResponse := `{
|
||
"Code": 200,
|
||
"Message": "Success",
|
||
"RequestId": "req-123",
|
||
"Data": {
|
||
"RiskLevel": "low",
|
||
"Detail": [{
|
||
"Suggestion": "mask",
|
||
"Type": "sensitiveData",
|
||
"Level": "S2",
|
||
"Result": [{
|
||
"Label": "phone_number",
|
||
"Confidence": 99.0,
|
||
"Ext": {
|
||
"Desensitization": "我的电话是1**********",
|
||
"SensitiveData": ["13800138000"]
|
||
}
|
||
}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// riskAction=block 时,mask 建议按风险等级判断,low 级别应放行
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
// 请求体不应被脱敏修改
|
||
processedBody := host.GetRequestBody()
|
||
require.NotNil(t, processedBody)
|
||
content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
|
||
require.Equal(t, "我的电话是13800138000", content)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试 riskAction=mask 时 block 建议优先拦截
|
||
t.Run("block suggestion takes priority over mask", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "违规内容"}]}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// API 返回 block 建议
|
||
securityResponse := `{
|
||
"Code": 200,
|
||
"Message": "Success",
|
||
"RequestId": "req-123",
|
||
"Data": {
|
||
"RiskLevel": "high",
|
||
"Detail": [{
|
||
"Suggestion": "block",
|
||
"Type": "contentModeration",
|
||
"Level": "high"
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// block 建议应拦截请求,请求体不应被修改
|
||
processedBody := host.GetRequestBody()
|
||
require.NotNil(t, processedBody)
|
||
content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
|
||
require.Equal(t, "违规内容", content)
|
||
})
|
||
})
|
||
}
|
||
|
||
// mcpMaskConfig is a test configuration for MCP traffic (apiType=mcp) checked
// by MultiModalGuard on both requests and responses, with riskAction=mask
// (masking mode).
var mcpMaskConfig = func() json.RawMessage {
	settings := map[string]interface{}{}

	// Security service endpoint and credentials.
	settings["serviceName"] = "security-service"
	settings["servicePort"] = 8080
	settings["serviceHost"] = "security.example.com"
	settings["accessKey"] = "test-ak"
	settings["secretKey"] = "test-sk"

	// Which directions are checked, and how.
	settings["checkRequest"] = true
	settings["checkResponse"] = true
	settings["action"] = "MultiModalGuard"
	settings["apiType"] = "mcp"
	settings["riskAction"] = "mask"

	// Where the checked text lives in MCP payloads.
	settings["requestContentJsonPath"] = "params.arguments.input"
	settings["responseContentJsonPath"] = "content"
	settings["responseStreamContentJsonPath"] = "content"

	// Thresholds and timeout.
	settings["contentModerationLevelBar"] = "high"
	settings["promptAttackLevelBar"] = "high"
	settings["sensitiveDataLevelBar"] = "S3"
	settings["timeout"] = 2000

	// A map of strings, ints and bools always marshals successfully.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||
|
||
func TestIsRiskLevelAcceptable(t *testing.T) {
|
||
// 需要 proxy-wasm host 环境,因为 evaluateRiskMultiModal 调用 proxywasm.LogInfof
|
||
opt := proxytest.NewEmulatorOption().WithVMContext(&types.DefaultVMContext{})
|
||
_, reset := proxytest.NewHostEmulator(opt)
|
||
defer reset()
|
||
|
||
// 用例 1: riskAction=mask, Suggestion=mask → 应返回 true(mask 不应被视为不可接受)
|
||
t.Run("mask action with mask suggestion is acceptable", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Suggestion: "mask", Type: "sensitiveData", Level: "S2",
|
||
Result: []cfg.Result{{Ext: cfg.Ext{Desensitization: "masked"}}}},
|
||
},
|
||
}
|
||
require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 用例 2: riskAction=mask, Suggestion=block → 应返回 false
|
||
t.Run("mask action with block suggestion is not acceptable", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Suggestion: "block", Type: "contentModeration", Level: "high"},
|
||
},
|
||
}
|
||
require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 用例 3: riskAction=mask, 无风险 → 应返回 true
|
||
t.Run("mask action with no risk is acceptable", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
}
|
||
data := cfg.Data{RiskLevel: "low"}
|
||
require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 用例 4: riskAction=block, Suggestion=mask, level 未超阈值 → 应返回 true(向后兼容)
|
||
t.Run("block action with mask suggestion below threshold is acceptable", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Suggestion: "mask", Type: "sensitiveData", Level: "S2"},
|
||
},
|
||
}
|
||
require.True(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 用例 5: riskAction=block, Suggestion=mask, level 超阈值 → 应返回 false(向后兼容)
|
||
t.Run("block action with mask suggestion exceeding threshold is not acceptable", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "block",
|
||
ContentModerationLevelBar: "max",
|
||
PromptAttackLevelBar: "max",
|
||
SensitiveDataLevelBar: "S2",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
}
|
||
data := cfg.Data{
|
||
RiskLevel: "none",
|
||
Detail: []cfg.Detail{
|
||
{Suggestion: "mask", Type: "sensitiveData", Level: "S2"},
|
||
},
|
||
}
|
||
require.False(t, cfg.IsRiskLevelAcceptable(cfg.MultiModalGuard, data, config, ""))
|
||
})
|
||
|
||
// 用例 6: TextModerationPlus, riskAction=mask → 不受影响
|
||
t.Run("TextModerationPlus not affected by riskAction mask", func(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
RiskAction: "mask",
|
||
RiskLevelBar: "high",
|
||
}
|
||
data := cfg.Data{RiskLevel: "low"}
|
||
require.True(t, cfg.IsRiskLevelAcceptable(cfg.TextModerationPlus, data, config, ""))
|
||
})
|
||
}
|
||
|
||
// TestMcpMaskNotBlock checks MCP traffic under mcpMaskConfig (riskAction=mask).
// A "mask" suggestion from the security service lets the request/response
// continue. A "block" suggestion makes the plugin answer the request locally.
func TestMcpMaskNotBlock(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Case 7: MCP request, riskAction=mask, service suggests mask -> pass through.
		t.Run("mcp request with mask suggestion should pass not block", func(t *testing.T) {
			host, status := test.NewTestHost(mcpMaskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/mcp"},
				{":method", "POST"},
			})

			body := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test","arguments":{"input":"我的电话是13800138000"}}}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// Security service answers with a mask suggestion.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-123",
				"Data": {
					"RiskLevel": "low",
					"Detail": [{
						"Suggestion": "mask",
						"Type": "sensitiveData",
						"Level": "S2",
						"Result": [{
							"Label": "phone_number",
							"Confidence": 99.0,
							"Ext": {
								"Desensitization": "我的电话是1**********",
								"SensitiveData": ["13800138000"]
							}
						}]
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			// Must be allowed through rather than blocked.
			action = host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})

		// Case 8: MCP response, riskAction=mask, service suggests mask -> pass through.
		t.Run("mcp response with mask suggestion should pass not block", func(t *testing.T) {
			host, status := test.NewTestHost(mcpMaskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/mcp"},
				{":method", "POST"},
			})

			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})

			body := `{"content": "我的电话是13800138000"}`
			action := host.CallOnHttpResponseBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// Security service answers with a mask suggestion.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-123",
				"Data": {
					"RiskLevel": "low",
					"Detail": [{
						"Suggestion": "mask",
						"Type": "sensitiveData",
						"Level": "S2",
						"Result": [{
							"Label": "phone_number",
							"Confidence": 99.0,
							"Ext": {
								"Desensitization": "我的电话是1**********",
								"SensitiveData": ["13800138000"]
							}
						}]
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			// Must be allowed through rather than blocked.
			action = host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})

		// Case 9: MCP request, riskAction=mask, service suggests block -> deny.
		t.Run("mcp request with block suggestion should deny", func(t *testing.T) {
			host, status := test.NewTestHost(mcpMaskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/mcp"},
				{":method", "POST"},
			})

			body := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"test","arguments":{"input":"违规内容"}}}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// Security service answers with a block suggestion.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-123",
				"Data": {
					"RiskLevel": "high",
					"Detail": [{
						"Suggestion": "block",
						"Type": "contentModeration",
						"Level": "high"
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			// A block suggestion must deny. The MCP handler answers via
			// SendHttpResponse and does not resume the request. Without this
			// assertion the case could never fail.
			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for MCP request with block suggestion")
		})
	})
}
|
||
|
||
// =============================================================================
|
||
// TC-PARSE: 配置解析与校验集成测试
|
||
// =============================================================================
|
||
|
||
// 测试配置:MultiModalGuard + 全局维度动作全为合法值
|
||
var dimensionActionValidConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkResponse": true,
|
||
"action": "MultiModalGuard",
|
||
"contentModerationAction": "block",
|
||
"promptAttackAction": "block",
|
||
"sensitiveDataAction": "mask",
|
||
"maliciousUrlAction": "block",
|
||
"modelHallucinationAction": "block",
|
||
"customLabelAction": "block",
|
||
"contentModerationLevelBar": "high",
|
||
"promptAttackLevelBar": "high",
|
||
"sensitiveDataLevelBar": "S3",
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:MultiModalGuard + 全局维度动作出现非法值
|
||
var dimensionActionInvalidConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkResponse": true,
|
||
"action": "MultiModalGuard",
|
||
"contentModerationAction": "allow", // 非法值
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:MultiModalGuard + consumerRiskLevel 内维度动作非法
|
||
var consumerDimensionActionInvalidConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"action": "MultiModalGuard",
|
||
"consumerRiskLevel": []map[string]interface{}{
|
||
{
|
||
"name": "user-a",
|
||
"matchType": "exact",
|
||
"sensitiveDataAction": "deny", // 非法值
|
||
},
|
||
},
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:TextModerationPlus + 配置了维度动作
|
||
var textModPlusDimensionActionConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"action": "TextModerationPlus",
|
||
"sensitiveDataAction": "mask",
|
||
"contentModerationAction": "block",
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:未配置任何动作字段
|
||
var noActionFieldConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// TestTC_PARSE_001 verifies that a MultiModalGuard config whose global
// per-dimension actions are all legal starts successfully, and that each
// action is parsed into the matching AISecurityConfig field.
func TestTC_PARSE_001(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(dimensionActionValidConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		config, err := host.GetMatchConfig()
		require.NoError(t, err)
		// Checked assertion: an unchecked one panics on a type mismatch, which
		// aborts the whole test binary instead of failing only this test.
		securityConfig, ok := config.(*cfg.AISecurityConfig)
		require.True(t, ok, "unexpected match config type %T", config)

		require.Equal(t, "block", securityConfig.ContentModerationAction)
		require.Equal(t, "block", securityConfig.PromptAttackAction)
		require.Equal(t, "mask", securityConfig.SensitiveDataAction)
		require.Equal(t, "block", securityConfig.MaliciousUrlAction)
		require.Equal(t, "block", securityConfig.ModelHallucinationAction)
		require.Equal(t, "block", securityConfig.CustomLabelAction)
	})
}
|
||
|
||
// TestTC_PARSE_002 verifies that a MultiModalGuard config carrying an illegal
// global per-dimension action makes plugin start-up fail.
func TestTC_PARSE_002(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		h, startStatus := test.NewTestHost(dimensionActionInvalidConfig)
		defer h.Reset()

		wantStatus := types.OnPluginStartStatusFailed
		require.Equal(t, wantStatus, startStatus)
	})
}
|
||
|
||
// TestTC_PARSE_003 verifies that a MultiModalGuard config with an illegal
// per-dimension action inside a consumerRiskLevel entry makes plugin start-up
// fail.
func TestTC_PARSE_003(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		h, startStatus := test.NewTestHost(consumerDimensionActionInvalidConfig)
		defer h.Reset()

		wantStatus := types.OnPluginStartStatusFailed
		require.Equal(t, wantStatus, startStatus)
	})
}
|
||
|
||
// TestTC_PARSE_004 verifies that a TextModerationPlus config which also sets
// per-dimension actions still starts successfully. The fields are parsed, but
// this action type ignores them at runtime.
func TestTC_PARSE_004(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(textModPlusDimensionActionConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		config, err := host.GetMatchConfig()
		require.NoError(t, err)
		// Checked assertion: an unchecked one panics on a type mismatch, which
		// aborts the whole test binary instead of failing only this test.
		securityConfig, ok := config.(*cfg.AISecurityConfig)
		require.True(t, ok, "unexpected match config type %T", config)

		// Parsed, but ignored at runtime; start-up is not affected.
		require.Equal(t, "mask", securityConfig.SensitiveDataAction)
		require.Equal(t, "block", securityConfig.ContentModerationAction)
	})
}
|
||
|
||
// TestTC_PARSE_005 verifies that a config without any action field starts
// successfully, defaults riskAction to "block", and leaves every per-dimension
// action empty.
func TestTC_PARSE_005(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(noActionFieldConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		config, err := host.GetMatchConfig()
		require.NoError(t, err)
		// Checked assertion: an unchecked one panics on a type mismatch, which
		// aborts the whole test binary instead of failing only this test.
		securityConfig, ok := config.(*cfg.AISecurityConfig)
		require.True(t, ok, "unexpected match config type %T", config)

		require.Equal(t, "block", securityConfig.RiskAction)
		require.Equal(t, "", securityConfig.ContentModerationAction)
		require.Equal(t, "", securityConfig.PromptAttackAction)
		require.Equal(t, "", securityConfig.SensitiveDataAction)
		require.Equal(t, "", securityConfig.MaliciousUrlAction)
		require.Equal(t, "", securityConfig.ModelHallucinationAction)
		require.Equal(t, "", securityConfig.CustomLabelAction)
	})
}
|
||
|
||
// TestTC_PARSE_006 TextModerationPlus + 非法维度动作值 => 启动成功(需求 8.2)
|
||
func TestTC_PARSE_006(t *testing.T) {
|
||
// 非 MultiModalGuard 下配置了非法维度动作值,不应报错,应启动成功
|
||
invalidDimActionTextModConfig := func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"action": "TextModerationPlus",
|
||
"contentModerationAction": "allow", // 非法值,但非 MultiModalGuard 下应忽略
|
||
"sensitiveDataAction": "deny", // 非法值
|
||
})
|
||
return data
|
||
}()
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
host, status := test.NewTestHost(invalidDimActionTextModConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
})
|
||
}
|
||
|
||
// TestTC_PARSE_007 TextModerationPlus + consumerRiskLevel 内非法维度动作值 => 启动成功(需求 8.2)
|
||
func TestTC_PARSE_007(t *testing.T) {
|
||
invalidConsumerDimActionTextModConfig := func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"action": "TextModerationPlus",
|
||
"consumerRiskLevel": []map[string]interface{}{
|
||
{
|
||
"name": "user-a",
|
||
"matchType": "exact",
|
||
"sensitiveDataAction": "invalid-value", // 非法值,但非 MultiModalGuard 下应忽略
|
||
},
|
||
},
|
||
})
|
||
return data
|
||
}()
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
host, status := test.NewTestHost(invalidConsumerDimActionTextModConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
})
|
||
}
|
||
|
||
// =============================================================================
|
||
// TC-REG: 回归测试
|
||
// =============================================================================
|
||
|
||
// 测试配置:历史仅 riskAction=block 的 MultiModalGuard 配置(无维度动作字段)
|
||
var legacyRiskActionBlockConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkResponse": true,
|
||
"action": "MultiModalGuard",
|
||
"riskAction": "block",
|
||
"contentModerationLevelBar": "high",
|
||
"promptAttackLevelBar": "high",
|
||
"sensitiveDataLevelBar": "S3",
|
||
"timeout": 2000,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:历史仅 riskAction=mask 的 MultiModalGuard 配置(无维度动作字段)
|
||
var legacyRiskActionMaskConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkResponse": true,
|
||
"action": "MultiModalGuard",
|
||
"riskAction": "mask",
|
||
"contentModerationLevelBar": "high",
|
||
"promptAttackLevelBar": "high",
|
||
"sensitiveDataLevelBar": "S3",
|
||
"timeout": 2000,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// TestTC_REG_004 guards against regressions for legacy configs that only set
// riskAction (no per-dimension action fields). Behavior must match the
// historical one for both riskAction=block and riskAction=mask.
func TestTC_REG_004(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Subcase 1: riskAction=block, security check passes (no risk) -> allowed.
		t.Run("legacy block config pass on low risk", func(t *testing.T) {
			host, status := test.NewTestHost(legacyRiskActionBlockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "Hello"}]}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// No risk and no triggered Detail entries.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-reg-001",
				"Data": {
					"RiskLevel": "none",
					"Detail": []
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			action = host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)
			host.CompleteHttp()
		})

		// Subcase 2: riskAction=block, top-level RiskLevel reaches the bar -> blocked.
		t.Run("legacy block config blocks on high risk level", func(t *testing.T) {
			host, status := test.NewTestHost(legacyRiskActionBlockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "违规内容"}]}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// High top-level RiskLevel (config bar is "high").
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-reg-002",
				"Data": {
					"RiskLevel": "high",
					"Detail": [{
						"Suggestion": "block",
						"Type": "contentModeration",
						"Level": "high"
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			// Blocked: the plugin must answer locally. An unchanged body alone
			// would also hold if the request had simply passed through.
			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for blocked request")

			// The request body must not be modified.
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
			require.Equal(t, "违规内容", content)
		})

		// Subcase 3: riskAction=block, a mask suggestion below the bar -> allowed,
		// and no desensitization (block mode ignores mask suggestions).
		t.Run("legacy block config ignores mask suggestion below threshold", func(t *testing.T) {
			host, status := test.NewTestHost(legacyRiskActionBlockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// Mask suggestion, but riskAction=block and level S2 is below bar S3.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-reg-003",
				"Data": {
					"RiskLevel": "none",
					"Detail": [{
						"Suggestion": "mask",
						"Type": "sensitiveData",
						"Level": "S2",
						"Result": [{
							"Label": "phone_number",
							"Confidence": 99.0,
							"Ext": {
								"Desensitization": "我的电话是1**********",
								"SensitiveData": ["13800138000"]
							}
						}]
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			action = host.GetHttpStreamAction()
			require.Equal(t, types.ActionContinue, action)

			// The body must not be desensitized.
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
			require.Equal(t, "我的电话是13800138000", content)
			host.CompleteHttp()
		})

		// Subcase 4: riskAction=mask, mask suggestion -> body is desensitized.
		t.Run("legacy mask config applies desensitization", func(t *testing.T) {
			host, status := test.NewTestHost(legacyRiskActionMaskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "我的电话是13800138000"}]}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// Mask suggestion; riskAction=mask should trigger desensitization.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-reg-004",
				"Data": {
					"RiskLevel": "none",
					"Detail": [{
						"Suggestion": "mask",
						"Type": "sensitiveData",
						"Level": "S3",
						"Result": [{
							"Label": "phone_number",
							"Confidence": 99.0,
							"Ext": {
								"Desensitization": "我的电话是1**********",
								"SensitiveData": ["13800138000"]
							}
						}]
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			// The masked text replaces the original content.
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
			require.Equal(t, "我的电话是1**********", content)
			host.CompleteHttp()
		})

		// Subcase 5: riskAction=mask, block suggestion -> still blocked (block wins).
		t.Run("legacy mask config still blocks on block suggestion", func(t *testing.T) {
			host, status := test.NewTestHost(legacyRiskActionMaskConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			body := `{"messages": [{"role": "user", "content": "违规内容"}]}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			// Block suggestion: must block even though riskAction=mask.
			securityResponse := `{
				"Code": 200,
				"Message": "Success",
				"RequestId": "req-reg-005",
				"Data": {
					"RiskLevel": "high",
					"Detail": [{
						"Suggestion": "block",
						"Type": "contentModeration",
						"Level": "high"
					}]
				}
			}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			// Blocked: the plugin must answer locally. An unchanged body alone
			// would also hold if the request had simply passed through.
			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for blocked request")

			// The request body must not be modified.
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
			require.Equal(t, "违规内容", content)
		})
	})
}
|
||
// TestMultiModalGuardTextGenerationDeny exercises MultiModalGuard on chat
// completion traffic:
//   - a high-risk verdict on the request yields a local reply containing
//     "blockedDetails";
//   - a high-risk verdict on the response yields such a reply too;
//   - a low-risk verdict on the request lets it continue.
func TestMultiModalGuardTextGenerationDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Each call returns a fresh slice, so subtests never share header storage.
		chatRequestHeaders := func() [][2]string {
			return [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			}
		}
		// Used both as upstream response headers and as security-service callback headers.
		jsonOKHeaders := func() [][2]string {
			return [][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}
		}

		// Request deny: covers BuildDenyResponseBody in multi_modal_guard/text/openai.go.
		t.Run("multi modal guard text request deny returns blockedDetails", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardTextConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(chatRequestHeaders())
			reqBody := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(reqBody)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-text-deny", "Data": {"RiskLevel": "high"}}`
			h.CallOnHttpCall(jsonOKHeaders(), []byte(verdict))

			reply := h.GetLocalResponse()
			require.NotNil(t, reply, "expected SendHttpResponse for request deny")
			require.Contains(t, string(reply.Data), "blockedDetails")
		})

		// Response deny: covers BuildDenyResponseBody reached from
		// HandleTextGenerationResponseBody in common/text/openai.go.
		t.Run("multi modal guard text response deny returns blockedDetails", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardTextConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(chatRequestHeaders())
			h.CallOnHttpResponseHeaders(jsonOKHeaders())

			respBody := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}`
			require.Equal(t, types.ActionPause, h.CallOnHttpResponseBody([]byte(respBody)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-resp-deny", "Data": {"RiskLevel": "high"}}`
			h.CallOnHttpCall(jsonOKHeaders(), []byte(verdict))

			reply := h.GetLocalResponse()
			require.NotNil(t, reply, "expected SendHttpResponse for response deny")
			require.Contains(t, string(reply.Data), "blockedDetails")
		})

		// Request pass on a low-risk verdict.
		t.Run("multi modal guard text request pass", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardTextConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(chatRequestHeaders())
			reqBody := `{"messages": [{"role": "user", "content": "Hello"}]}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(reqBody)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-pass", "Data": {"RiskLevel": "low"}}`
			h.CallOnHttpCall(jsonOKHeaders(), []byte(verdict))

			require.Equal(t, types.ActionContinue, h.GetHttpStreamAction())
			h.CompleteHttp()
		})
	})
}
|
||
|
||
// TestMultiModalGuardImageGenerationDeny drives image-generation requests for
// both the OpenAI-style and the Qwen-style configs. A high-risk verdict must
// produce a local reply containing "blockedDetails"; a low-risk verdict must
// let the request continue.
func TestMultiModalGuardImageGenerationDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Each call returns a fresh slice, so subtests never share header storage.
		imageRequestHeaders := func() [][2]string {
			return [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
			}
		}
		callbackHeaders := func() [][2]string {
			return [][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}
		}

		// OpenAI deny: covers BuildDenyResponseBody in multi_modal_guard/image/openai.go.
		t.Run("openai image request deny returns blockedDetails", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardImageOpenAIConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(imageRequestHeaders())
			prompt := `{"prompt": "generate bad image"}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(prompt)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-img-openai-deny", "Data": {"RiskLevel": "high"}}`
			h.CallOnHttpCall(callbackHeaders(), []byte(verdict))

			reply := h.GetLocalResponse()
			require.NotNil(t, reply, "expected SendHttpResponse for OpenAI image request deny")
			require.Contains(t, string(reply.Data), "blockedDetails")
		})

		// OpenAI pass on a low-risk verdict.
		t.Run("openai image request pass", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardImageOpenAIConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(imageRequestHeaders())
			prompt := `{"prompt": "a cute cat"}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(prompt)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-img-pass", "Data": {"RiskLevel": "low"}}`
			h.CallOnHttpCall(callbackHeaders(), []byte(verdict))

			require.Equal(t, types.ActionContinue, h.GetHttpStreamAction())
			h.CompleteHttp()
		})

		// Qwen deny: covers BuildDenyResponseBody in multi_modal_guard/image/qwen.go.
		t.Run("qwen image request deny returns blockedDetails", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardImageQwenConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(imageRequestHeaders())
			prompt := `{"input": {"prompt": "generate bad image"}}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(prompt)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-img-qwen-deny", "Data": {"RiskLevel": "high"}}`
			h.CallOnHttpCall(callbackHeaders(), []byte(verdict))

			reply := h.GetLocalResponse()
			require.NotNil(t, reply, "expected SendHttpResponse for Qwen image request deny")
			require.Contains(t, string(reply.Data), "blockedDetails")
		})

		// Qwen pass on a low-risk verdict.
		t.Run("qwen image request pass", func(t *testing.T) {
			h, st := test.NewTestHost(multiModalGuardImageQwenConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(imageRequestHeaders())
			prompt := `{"input": {"prompt": "a cute cat"}}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(prompt)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-qwen-pass", "Data": {"RiskLevel": "low"}}`
			h.CallOnHttpCall(callbackHeaders(), []byte(verdict))

			require.Equal(t, types.ActionContinue, h.GetHttpStreamAction())
			h.CompleteHttp()
		})
	})
}
|
||
|
||
// TestMCPRequestDeny covers MCP request handling under mcpRequestConfig:
//   - a tools/call request with a high-risk verdict is answered locally with
//     "blockedDetails";
//   - a tools/call request with a low-risk verdict continues;
//   - a non-tools/call method is let through without a security check.
func TestMCPRequestDeny(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Each call returns a fresh slice, so subtests never share header storage.
		mcpHeaders := func() [][2]string {
			return [][2]string{
				{":authority", "example.com"},
				{":path", "/mcp/call"},
				{":method", "POST"},
			}
		}
		callbackHeaders := func() [][2]string {
			return [][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}
		}

		// Deny: covers BuildDenyResponseBody in HandleMcpRequestBody (multi_modal_guard/mcp/mcp.go).
		t.Run("mcp request deny returns blockedDetails", func(t *testing.T) {
			h, st := test.NewTestHost(mcpRequestConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(mcpHeaders())
			call := `{"method": "tools/call", "params": {"arguments": "bad request content"}}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(call)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-deny", "Data": {"RiskLevel": "high"}}`
			h.CallOnHttpCall(callbackHeaders(), []byte(verdict))

			reply := h.GetLocalResponse()
			require.NotNil(t, reply, "expected SendHttpResponse for MCP request deny")
			require.Contains(t, string(reply.Data), "blockedDetails")
		})

		// Pass on a low-risk verdict.
		t.Run("mcp request pass", func(t *testing.T) {
			h, st := test.NewTestHost(mcpRequestConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(mcpHeaders())
			call := `{"method": "tools/call", "params": {"arguments": "safe content"}}`
			require.Equal(t, types.ActionPause, h.CallOnHttpRequestBody([]byte(call)))

			verdict := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-pass", "Data": {"RiskLevel": "low"}}`
			h.CallOnHttpCall(callbackHeaders(), []byte(verdict))

			require.Equal(t, types.ActionContinue, h.GetHttpStreamAction())
			h.CompleteHttp()
		})

		// Methods other than tools/call continue immediately.
		t.Run("mcp request skip non-tool-call", func(t *testing.T) {
			h, st := test.NewTestHost(mcpRequestConfig)
			defer h.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, st)

			h.CallOnHttpRequestHeaders(mcpHeaders())
			listCall := `{"method": "resources/list", "params": {}}`
			require.Equal(t, types.ActionContinue, h.CallOnHttpRequestBody([]byte(listCall)))
		})
	})
}
|
||
|
||
func TestTextModerationPlusResponseDeny(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// TextModerationPlus response deny → exercises text_moderation_plus/text (via common/text) BuildDenyResponseBody response path
|
||
t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
body := `{"choices": [{"message": {"role": "assistant", "content": "bad response"}}]}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-resp-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for response deny")
|
||
require.Contains(t, string(local.Data), "blockedDetails")
|
||
|
||
// Verify OpenAI completion shape wrapper in structured mode:
|
||
// message.content carries only the human-readable deny text and the
|
||
// structured deny payload moves to choices[0].x_higress_guardrail.
|
||
type openAIChatCompletion struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
Guardrail cfg.DenyResponseBody `json:"x_higress_guardrail"`
|
||
} `json:"choices"`
|
||
}
|
||
var outer openAIChatCompletion
|
||
require.NoError(t, json.Unmarshal(local.Data, &outer))
|
||
require.Len(t, outer.Choices, 1)
|
||
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content)
|
||
require.Equal(t, 200, outer.Choices[0].Guardrail.Code)
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
|
||
require.NotEmpty(t, outer.Choices[0].Guardrail.BlockedDetails)
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestBuildDenyResponseBody(t *testing.T) {
|
||
makeConfig := func(contentBar, promptBar string) cfg.AISecurityConfig {
|
||
return cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: contentBar,
|
||
PromptAttackLevelBar: promptBar,
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
}
|
||
}
|
||
|
||
t.Run("code equals response.Code", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-123",
|
||
Data: cfg.Data{},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, makeConfig("high", "high"), "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.Equal(t, 200, result.Code)
|
||
})
|
||
|
||
t.Run("blockedDetails from Data.Detail", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-456",
|
||
Data: cfg.Data{
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
|
||
{Type: cfg.PromptAttackType, Level: "low", Suggestion: "none"},
|
||
},
|
||
},
|
||
}
|
||
config := makeConfig("high", "high")
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
// only the contentModeration entry meets the "high" bar; promptAttack at "low" does not
|
||
require.Len(t, result.BlockedDetails, 1)
|
||
require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type)
|
||
require.Equal(t, "high", result.BlockedDetails[0].Level)
|
||
})
|
||
|
||
t.Run("blockedDetails empty when suggestion=block but below threshold", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-suggestion-block",
|
||
Data: cfg.Data{
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.SensitiveDataType, Level: "S3", Suggestion: "block"},
|
||
},
|
||
},
|
||
}
|
||
config := makeConfig("high", "high")
|
||
config.SensitiveDataLevelBar = "S4"
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.Len(t, result.BlockedDetails, 0)
|
||
})
|
||
|
||
t.Run("blockedDetails includes customLabel when threshold exceeded", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-custom-label",
|
||
Data: cfg.Data{
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.CustomLabelType, Level: "high", Suggestion: "none"},
|
||
},
|
||
},
|
||
}
|
||
config := makeConfig("high", "high")
|
||
config.CustomLabelLevelBar = "high"
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.Len(t, result.BlockedDetails, 1)
|
||
require.Equal(t, cfg.CustomLabelType, result.BlockedDetails[0].Type)
|
||
require.Equal(t, "high", result.BlockedDetails[0].Level)
|
||
})
|
||
|
||
t.Run("blockedDetails fallback from RiskLevel when Detail is empty", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-789",
|
||
Data: cfg.Data{
|
||
RiskLevel: "high",
|
||
// Detail deliberately empty
|
||
},
|
||
}
|
||
config := makeConfig("high", "high")
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from RiskLevel")
|
||
require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type)
|
||
require.Equal(t, "high", result.BlockedDetails[0].Level)
|
||
})
|
||
|
||
t.Run("blockedDetails fallback from AttackLevel when Detail is empty", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-abc",
|
||
Data: cfg.Data{
|
||
AttackLevel: "high",
|
||
// Detail deliberately empty
|
||
},
|
||
}
|
||
config := makeConfig("high", "high")
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from AttackLevel")
|
||
require.Equal(t, cfg.PromptAttackType, result.BlockedDetails[0].Type)
|
||
require.Equal(t, "high", result.BlockedDetails[0].Level)
|
||
})
|
||
|
||
t.Run("blockedDetails empty when risk levels below threshold", func(t *testing.T) {
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-def",
|
||
Data: cfg.Data{
|
||
RiskLevel: "low",
|
||
AttackLevel: "low",
|
||
},
|
||
}
|
||
// threshold is "high", so "low" must not produce fallback entries
|
||
config := makeConfig("high", "high")
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.Empty(t, result.BlockedDetails)
|
||
})
|
||
}
|
||
|
||
func TestBuildDenyResponseBody_WithDenyMessage(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: "high",
|
||
PromptAttackLevelBar: "high",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
DenyMessage: "很抱歉,我无法回答您的问题",
|
||
}
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
Data: cfg.Data{
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
|
||
},
|
||
},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.Equal(t, "很抱歉,我无法回答您的问题", result.DenyMessage)
|
||
}
|
||
|
||
func TestBuildDenyResponseBody_WithoutDenyMessage(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: "high",
|
||
PromptAttackLevelBar: "high",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
}
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
Data: cfg.Data{
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
|
||
},
|
||
},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
require.NotContains(t, string(body), "denyMessage")
|
||
}
|
||
|
||
func TestBuildDenyResponseBody_BlockedDetailsOnlyTypeAndLevel(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: "high",
|
||
PromptAttackLevelBar: "high",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
}
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
Data: cfg.Data{
|
||
Detail: []cfg.Detail{
|
||
{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block", Result: []cfg.Result{{Label: "violence"}}},
|
||
{Type: cfg.PromptAttackType, Level: "high", Suggestion: "block", Result: []cfg.Result{{Label: "injection"}}},
|
||
},
|
||
},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var raw map[string]interface{}
|
||
require.NoError(t, json.Unmarshal(body, &raw))
|
||
details := raw["blockedDetails"].([]interface{})
|
||
require.Len(t, details, 2)
|
||
for _, entry := range details {
|
||
m := entry.(map[string]interface{})
|
||
require.Len(t, m, 2, "each blockedDetail entry should have exactly 2 keys (type and level)")
|
||
require.Contains(t, m, "type")
|
||
require.Contains(t, m, "level")
|
||
}
|
||
}
|
||
|
||
func TestBuildDenyResponseBody_CodeField(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: "high",
|
||
PromptAttackLevelBar: "high",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
}
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
Data: cfg.Data{},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var result cfg.DenyResponseBody
|
||
require.NoError(t, json.Unmarshal(body, &result))
|
||
require.Equal(t, 200, result.Code)
|
||
}
|
||
|
||
func TestBuildDenyResponseBody_NoRequestId(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: "high",
|
||
PromptAttackLevelBar: "high",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
}
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
RequestId: "req-should-not-appear",
|
||
Data: cfg.Data{},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
require.NotContains(t, string(body), "requestId")
|
||
}
|
||
|
||
func TestBuildDenyResponseBody_FallbackSynthesis(t *testing.T) {
|
||
config := cfg.AISecurityConfig{
|
||
ContentModerationLevelBar: "high",
|
||
PromptAttackLevelBar: "high",
|
||
SensitiveDataLevelBar: "S4",
|
||
MaliciousUrlLevelBar: "max",
|
||
ModelHallucinationLevelBar: "max",
|
||
CustomLabelLevelBar: "max",
|
||
RiskAction: "block",
|
||
Action: cfg.MultiModalGuard,
|
||
}
|
||
resp := cfg.Response{
|
||
Code: 200,
|
||
Data: cfg.Data{
|
||
RiskLevel: "high",
|
||
// No Detail entries — triggers fallback synthesis
|
||
},
|
||
}
|
||
body, err := cfg.BuildDenyResponseBody(resp, config, "")
|
||
require.NoError(t, err)
|
||
|
||
var raw map[string]interface{}
|
||
require.NoError(t, json.Unmarshal(body, &raw))
|
||
details := raw["blockedDetails"].([]interface{})
|
||
require.NotEmpty(t, details, "expected fallback synthesized entries")
|
||
for _, entry := range details {
|
||
m := entry.(map[string]interface{})
|
||
require.Len(t, m, 2, "fallback blockedDetail entry should have exactly 2 keys (type and level)")
|
||
require.Contains(t, m, "type")
|
||
require.Contains(t, m, "level")
|
||
}
|
||
}
|
||
|
||
// =============================================================================
// TC-COVER: supplementary coverage tests (覆盖率补充测试)
// =============================================================================
|
||
|
||
// TestMultiModalGuardStreamDeny 覆盖 openai.go RiskBlock 分支中 stream 响应格式路径
|
||
func TestMultiModalGuardStreamDeny(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("stream request deny returns SSE format", func(t *testing.T) {
|
||
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 请求体中包含 stream=true,触发 SSE 响应格式
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for stream deny")
|
||
require.Contains(t, string(local.Data), "blockedDetails")
|
||
// 验证 SSE content-type
|
||
foundSSE := false
|
||
for _, h := range local.Headers {
|
||
if h[0] == "content-type" {
|
||
require.Equal(t, "text/event-stream;charset=UTF-8", h[1])
|
||
foundSSE = true
|
||
}
|
||
}
|
||
require.True(t, foundSSE, "expected SSE content-type header")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestMultiModalGuardProtocolOriginalDeny 覆盖 openai.go RiskBlock 分支中 ProtocolOriginal 路径
|
||
func TestMultiModalGuardProtocolOriginalDeny(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("protocol original deny returns raw blockedDetails JSON", func(t *testing.T) {
|
||
host, status := test.NewTestHost(protocolOriginalConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-proto-orig", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for protocol original deny")
|
||
require.Contains(t, string(local.Data), "blockedDetails")
|
||
// ProtocolOriginal 直接返回 JSON,不包装 OpenAI 格式
|
||
for _, h := range local.Headers {
|
||
if h[0] == "content-type" {
|
||
require.Equal(t, "application/json", h[1])
|
||
}
|
||
}
|
||
// 响应体是原始 blockedDetails JSON,不含 OpenAI 包装
|
||
require.False(t, gjson.GetBytes(local.Data, "choices").Exists(), "should not wrap in OpenAI format")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestMultiModalGuardDenyWithAdvice 覆盖 openai.go RiskBlock 分支中 Advice != nil 路径
|
||
func TestMultiModalGuardDenyWithAdvice(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("deny with advice sets riskLabel and riskWords attributes", func(t *testing.T) {
|
||
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny with advice"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
// 包含 Advice 和 Result 的安全服务响应
|
||
securityResponse := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-advice-deny",
|
||
"Data": {
|
||
"RiskLevel": "high",
|
||
"Result": [{"Label": "porn", "RiskWords": "bad-word"}],
|
||
"Advice": [{"Answer": "blocked", "HitLabel": "porn"}],
|
||
"Detail": [{"Suggestion": "block", "Type": "contentModeration", "Level": "high"}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse")
|
||
require.Contains(t, string(local.Data), "blockedDetails")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestMultiChunkMasking 覆盖 openai.go 中 RiskPass + hasMasked 路径
|
||
// 场景:内容超过 LengthLimit(1800),第一 chunk 触发 RiskMask 脱敏替换,第二 chunk RiskPass
|
||
func TestMultiChunkMasking(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("multi chunk masking with pass on second chunk", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 生成超过 LengthLimit (1800) 的内容
|
||
longContent := strings.Repeat("a", 2000)
|
||
body := `{"messages": [{"role": "user", "content": "` + longContent + `"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
// 第一个 chunk (1800 chars):返回 mask 建议及脱敏内容
|
||
maskedChunk := strings.Repeat("b", 1800)
|
||
securityResponse1 := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-chunk-1",
|
||
"Data": {
|
||
"RiskLevel": "none",
|
||
"Detail": [{
|
||
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
||
"Result": [{"Label": "phone", "Confidence": 99.0,
|
||
"Ext": {"Desensitization": "` + maskedChunk + `"}}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse1))
|
||
|
||
// 第二个 chunk (200 chars):返回 pass(无风险)
|
||
securityResponse2 := `{"Code": 200, "Message": "Success", "RequestId": "req-chunk-2", "Data": {"RiskLevel": "none", "Detail": []}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse2))
|
||
|
||
// 验证请求体被替换为脱敏内容
|
||
processedBody := host.GetRequestBody()
|
||
require.NotNil(t, processedBody)
|
||
content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
|
||
// 期望 = 脱敏后的第一 chunk (1800 'b') + 原始第二 chunk (200 'a')
|
||
expectedContent := maskedChunk + strings.Repeat("a", 200)
|
||
require.Equal(t, expectedContent, content)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 覆盖 RiskMask 成功完成路径(单 chunk 内容刚好 <= LengthLimit,RiskMask 后立即完成)
|
||
// 该路径在 RiskMask 分支中 contentIndex >= len(maskedContent) 的子路径
|
||
t.Run("single chunk mask completes in RiskMask branch", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "我的银行卡号是6222021234567890"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-single-mask",
|
||
"Data": {
|
||
"RiskLevel": "none",
|
||
"Detail": [{
|
||
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
||
"Result": [{"Label": "bank_card", "Confidence": 99.0,
|
||
"Ext": {"Desensitization": "我的银行卡号是6222************"}}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
processedBody := host.GetRequestBody()
|
||
require.NotNil(t, processedBody)
|
||
content := gjson.GetBytes(processedBody, "messages.@reverse.0.content").String()
|
||
require.Equal(t, "我的银行卡号是6222************", content)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestMultiModalGuardMaskStreamDeny 覆盖 openai.go RiskMask 空脱敏 fallthrough 到 RiskBlock 的 stream 路径
|
||
func TestMultiModalGuardMaskStreamDeny(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("mask with empty desensitization falls through to block stream format", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// stream=true 使 block 走 SSE 格式
|
||
body := `{"messages": [{"role": "user", "content": "敏感内容"}], "stream": true}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
// 返回 mask 建议但脱敏内容为空 → fallthrough 到 RiskBlock
|
||
securityResponse := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-mask-stream-deny",
|
||
"Data": {
|
||
"RiskLevel": "none",
|
||
"Detail": [{
|
||
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
||
"Result": [{"Label": "phone", "Confidence": 99.0,
|
||
"Ext": {"Desensitization": ""}}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse after mask fallthrough to block")
|
||
// 验证是 SSE 格式
|
||
foundSSE := false
|
||
for _, h := range local.Headers {
|
||
if h[0] == "content-type" {
|
||
require.Equal(t, "text/event-stream;charset=UTF-8", h[1])
|
||
foundSSE = true
|
||
}
|
||
}
|
||
require.True(t, foundSSE, "expected SSE content-type for stream deny")
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOpenAIDenyLegacyDefaultNonStream(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("default legacy non-stream response keeps deny body in content", func(t *testing.T) {
|
||
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-default", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
require.Equal(t, "stop", gjson.GetBytes(local.Data, "choices.0.finish_reason").String())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
|
||
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
|
||
denyBody := mustDecodeLegacyDenyContent(t, content)
|
||
require.Equal(t, 200, denyBody.Code)
|
||
require.NotEmpty(t, denyBody.BlockedDetails)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOpenAIDenyLegacyDefaultStream(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("default legacy stream response keeps deny body in first content frame", func(t *testing.T) {
|
||
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-stream", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
raw := string(local.Data)
|
||
require.True(t, strings.HasSuffix(strings.TrimSpace(raw), "data: [DONE]"))
|
||
parts := strings.Split(raw, "\n\n")
|
||
require.GreaterOrEqual(t, len(parts), 3)
|
||
|
||
firstFrame := strings.TrimSpace(strings.TrimPrefix(parts[0], "data:"))
|
||
endFrame := strings.TrimSpace(strings.TrimPrefix(parts[1], "data:"))
|
||
|
||
firstContent := gjson.Get(firstFrame, "choices.0.delta.content").String()
|
||
denyBody := mustDecodeLegacyDenyContent(t, firstContent)
|
||
require.Equal(t, 200, denyBody.Code)
|
||
require.NotEmpty(t, denyBody.BlockedDetails)
|
||
|
||
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress_guardrail").Exists())
|
||
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress").Exists())
|
||
require.False(t, gjson.Get(endFrame, "choices.0.x_higress_guardrail").Exists())
|
||
require.False(t, gjson.Get(endFrame, "choices.0.x_higress").Exists())
|
||
require.Equal(t, "stop", gjson.Get(endFrame, "choices.0.finish_reason").String())
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOpenAIDenyLegacyDenyCodeKeepsResponseCodeInContent(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("legacy content code remains safecheck response code when denyCode differs", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withConfigOverrides(multiModalGuardTextConfig, map[string]interface{}{
|
||
"denyCode": 451,
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-legacy-451", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
require.Equal(t, uint32(451), local.StatusCode)
|
||
|
||
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
|
||
denyBody := mustDecodeLegacyDenyContent(t, content)
|
||
require.Equal(t, 200, denyBody.Code)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestMaskEmptyDesensitizationOpenAIFormats(t *testing.T) {
|
||
securityResponse := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-mask-empty-openai",
|
||
"Data": {
|
||
"RiskLevel": "none",
|
||
"Detail": [{
|
||
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
||
"Result": [{"Label": "phone", "Confidence": 99.0,
|
||
"Ext": {"Desensitization": ""}}]
|
||
}]
|
||
}
|
||
}`
|
||
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("legacy empty desensitization uses json-stringified deny body", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
|
||
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
|
||
denyBody := mustDecodeLegacyDenyContent(t, content)
|
||
require.Equal(t, 200, denyBody.Code)
|
||
require.Empty(t, denyBody.BlockedDetails)
|
||
})
|
||
|
||
t.Run("structured empty desensitization uses fallback guardrail", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(maskConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String())
|
||
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
require.Equal(t, int64(0), gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.blockedDetails.#").Int())
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestMaskReplaceJsonFieldFailureOpenAIFormats(t *testing.T) {
|
||
securityResponse := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-mask-replace-failure",
|
||
"Data": {
|
||
"RiskLevel": "none",
|
||
"Detail": [{
|
||
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
||
"Result": [{"Label": "phone", "Confidence": 99.0,
|
||
"Ext": {"Desensitization": "masked"}}]
|
||
}]
|
||
}
|
||
}`
|
||
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("legacy replace failure keeps pure deny message content", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withConfigOverrides(maskConfig, map[string]interface{}{
|
||
"requestContentJsonPath": "@this.messages.0.content",
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").Exists())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
})
|
||
|
||
t.Run("structured replace failure emits fallback guardrail", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withConfigOverrides(withStructuredFormat(maskConfig), map[string]interface{}{
|
||
"requestContentJsonPath": "@this.messages.0.content",
|
||
}))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
require.Equal(t, cfg.DefaultDenyMessage, gjson.GetBytes(local.Data, "choices.0.message.content").String())
|
||
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject())
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
require.Equal(t, int64(0), gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.blockedDetails.#").Int())
|
||
})
|
||
})
|
||
}
|
||
|
||
// =============================================================================
// x_higress_guardrail 扩展字段链路与模板测试
// (x_higress_guardrail extension field: end-to-end path and template tests)
// =============================================================================
|
||
|
||
// openAIChoiceWithGuardrail is the minimal OpenAI choice shape used by
|
||
// x_higress_guardrail assertions. Guardrail is unmarshaled directly into the strongly
|
||
// typed cfg.DenyResponseBody so tests assert the documented contract — code
|
||
// is int, denyMessage is string, blockedDetails is a slice — instead of
|
||
// silently tolerating shape drift through map[string]interface{}.
|
||
type openAIChoiceWithGuardrail struct {
|
||
Index int `json:"index"`
|
||
Message struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
Delta struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
} `json:"delta"`
|
||
FinishReason *string `json:"finish_reason"`
|
||
Guardrail cfg.DenyResponseBody `json:"x_higress_guardrail"`
|
||
}
|
||
|
||
// openAIBodyWithGuardrail also carries top-level extension fields so tests can
|
||
// assert the design contract that x_higress_guardrail lives ONLY inside choices[0] —
|
||
// a regression where it leaks to the body root would deserialize into this
|
||
// field and the require.Empty check at the call site would fail.
|
||
type openAIBodyWithGuardrail struct {
|
||
Choices []openAIChoiceWithGuardrail `json:"choices"`
|
||
Guardrail *cfg.DenyResponseBody `json:"x_higress_guardrail,omitempty"`
|
||
XHigress *cfg.DenyResponseBody `json:"x_higress,omitempty"`
|
||
}
|
||
|
||
// TestRequestDenyGuardrailNonStream verifies that 请求阶段非流式 deny renders
|
||
// content as plain text and embeds x_higress_guardrail as a JSON object.
|
||
func TestRequestDenyGuardrailNonStream(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("non-stream request deny carries guardrail object", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-x-higress", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
// x_higress_guardrail must be a JSON object, not a string
|
||
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(),
|
||
"x_higress_guardrail should be an embedded JSON object, not a string literal")
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(),
|
||
"structured deny must not emit the old choices[0].x_higress field")
|
||
|
||
var outer openAIBodyWithGuardrail
|
||
require.NoError(t, json.Unmarshal(local.Data, &outer))
|
||
require.Len(t, outer.Choices, 1)
|
||
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content,
|
||
"content should carry only the human-readable deny text")
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
|
||
// A-R1F4 contract: x_higress_guardrail.code carries the gateway-emitted HTTP deny
|
||
// status (config.DenyCode, default 200), NOT the upstream security
|
||
// service's Response.Code. multiModalGuardTextConfig leaves denyCode
|
||
// unset, so this resolves to cfg.DefaultDenyCode.
|
||
require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code,
|
||
"x_higress_guardrail.code must equal the gateway-emitted HTTP deny status")
|
||
require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails)
|
||
|
||
// Design contract: x_higress_guardrail lives ONLY nested under choices[0].
|
||
require.Nil(t, outer.Guardrail,
|
||
"x_higress_guardrail must not leak to body root; only choices[0].x_higress_guardrail is valid")
|
||
require.Nil(t, outer.XHigress, "old x_higress must not leak to body root")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root; only choices[0].x_higress_guardrail is valid")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestRequestDenyGuardrailStreamFrames verifies that 请求阶段流式 deny only
|
||
// embeds x_higress_guardrail in the final chunk and the first chunk carries plain text.
|
||
func TestRequestDenyGuardrailStreamFrames(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("stream request deny only attaches guardrail in last chunk", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}], "stream": true}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-x-higress", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
raw := string(local.Data)
|
||
require.True(t, strings.HasSuffix(strings.TrimSpace(raw), "data: [DONE]"))
|
||
parts := strings.Split(raw, "\n\n")
|
||
// 模板格式: data:<chunk>\n\ndata:<end>\n\ndata: [DONE]
|
||
require.GreaterOrEqual(t, len(parts), 3, "expected at least chunk + end + DONE")
|
||
|
||
firstFrame := strings.TrimPrefix(parts[0], "data:")
|
||
endFrame := strings.TrimPrefix(parts[1], "data:")
|
||
|
||
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress_guardrail").Exists(),
|
||
"first chunk should not carry x_higress_guardrail")
|
||
require.False(t, gjson.Get(firstFrame, "choices.0.x_higress").Exists(),
|
||
"first chunk should not carry old x_higress")
|
||
require.Equal(t, cfg.DefaultDenyMessage,
|
||
gjson.Get(firstFrame, "choices.0.delta.content").String(),
|
||
"first chunk delta.content should be plain text")
|
||
|
||
require.True(t, gjson.Get(endFrame, "choices.0.x_higress_guardrail").IsObject(),
|
||
"final chunk should carry x_higress_guardrail as object")
|
||
require.False(t, gjson.Get(endFrame, "choices.0.x_higress").Exists(),
|
||
"final chunk should not carry old x_higress")
|
||
// Deny stream's terminator carries `stop` for wire-level compatibility
|
||
// with downstream consumers (LangChain / LiteLLM / SDKs / BI) that key
|
||
// off `stop` as a valid completion. The moderation-event signal lives
|
||
// in choices[0].x_higress_guardrail (denyCode / blockedDetails) instead.
|
||
require.Equal(t, "stop", gjson.Get(endFrame, "choices.0.finish_reason").String())
|
||
require.False(t, gjson.Get(endFrame, "choices.0.delta.content").Exists(),
|
||
"final chunk delta should be empty")
|
||
|
||
// Design contract: x_higress_guardrail lives ONLY nested under choices[0].
|
||
require.False(t, gjson.Get(endFrame, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root of the end frame")
|
||
require.False(t, gjson.Get(firstFrame, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root of the first frame")
|
||
require.False(t, gjson.Get(endFrame, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root of the end frame")
|
||
require.False(t, gjson.Get(firstFrame, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root of the first frame")
|
||
|
||
require.Contains(t, parts[2], "[DONE]")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestRequestDenyDefaultDenyMessage verifies that without a configured
|
||
// denyMessage, both content and x_higress_guardrail.denyMessage fall back to the default
|
||
// text.
|
||
func TestRequestDenyDefaultDenyMessage(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("default deny message used when not configured", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(multiModalGuardTextConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-default", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
content := gjson.GetBytes(local.Data, "choices.0.message.content").String()
|
||
denyMessage := gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail.denyMessage").String()
|
||
require.Equal(t, cfg.DefaultDenyMessage, content)
|
||
require.Equal(t, cfg.DefaultDenyMessage, denyMessage)
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists())
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestProtocolOriginalDenyShapePreserved guards the regression that the
|
||
// protocol: "original" normal deny path keeps the bare DenyResponseBody shape
|
||
// without OpenAI wrapping or x_higress_guardrail.
|
||
func TestProtocolOriginalDenyShapePreserved(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
for _, tc := range []struct {
|
||
name string
|
||
config json.RawMessage
|
||
}{
|
||
{name: "default format", config: protocolOriginalConfig},
|
||
{name: "legacy format", config: withConfigOverrides(protocolOriginalConfig, map[string]interface{}{"openAIDenyResponseFormat": "legacy"})},
|
||
{name: "structured format", config: withStructuredFormat(protocolOriginalConfig)},
|
||
} {
|
||
tc := tc
|
||
t.Run("original protocol deny stays as bare DenyResponseBody "+tc.name, func(t *testing.T) {
|
||
host, status := test.NewTestHost(tc.config)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-orig-shape", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
require.False(t, gjson.GetBytes(local.Data, "choices").Exists(),
|
||
"original protocol should not OpenAI-wrap the body")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
|
||
"original protocol should not introduce x_higress_guardrail")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
|
||
"original protocol should not introduce old x_higress")
|
||
require.True(t, gjson.GetBytes(local.Data, "blockedDetails").Exists())
|
||
require.Equal(t, int64(200), gjson.GetBytes(local.Data, "code").Int())
|
||
})
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestMaskEmptyDesensitizationOriginalShape guards the A-R1F2 alignment:
|
||
// when riskAction=mask and the upstream returns empty Desensitization under
|
||
// protocol: "original", the response body must be the bare JSON string literal
|
||
// produced by wrapper.MarshalStr(ResolveDenyMessage(config)) — i.e.
|
||
// `"<deny message>"` — mirroring the ReplaceJsonFieldTextContent failure path
|
||
// at lvwang/multi_modal_guard/text/openai.go:102 / :159.
|
||
//
|
||
// Before A-R1F2 this branch fell through to RiskBlock and called
|
||
// BuildDenyResponseBody, returning a structured {code, blockedDetails, ...}
|
||
// object instead. The fallthrough was inconsistent with design Section 5 and
|
||
// has been replaced with the self-handled MarshalStr path; this regression
|
||
// test locks in the new contract so the divergence cannot reappear silently.
|
||
func TestMaskEmptyDesensitizationOriginalShape(t *testing.T) {
|
||
maskOriginalConfig := func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkResponse": false,
|
||
"action": "MultiModalGuard",
|
||
"riskAction": "mask",
|
||
"protocol": "original",
|
||
"contentModerationLevelBar": "high",
|
||
"promptAttackLevelBar": "high",
|
||
"sensitiveDataLevelBar": "S3",
|
||
"timeout": 2000,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("mask empty desensitization under original emits MarshalStr literal", func(t *testing.T) {
|
||
host, status := test.NewTestHost(maskOriginalConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
// mask 但脱敏内容空 → 走 A-R1F2 自处理分支
|
||
securityResponse := `{
|
||
"Code": 200, "Message": "Success", "RequestId": "req-mask-empty-orig",
|
||
"Data": {
|
||
"RiskLevel": "none",
|
||
"Detail": [{
|
||
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
||
"Result": [{"Label": "phone", "Confidence": 99.0,
|
||
"Ext": {"Desensitization": ""}}]
|
||
}]
|
||
}
|
||
}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local)
|
||
|
||
// 不是 OpenAI 包装
|
||
require.False(t, gjson.GetBytes(local.Data, "choices").Exists(),
|
||
"mask empty-desensitization under original must not be OpenAI-wrapped")
|
||
require.False(t, gjson.GetBytes(local.Data, "object").Exists(),
|
||
"mask empty-desensitization under original must not include OpenAI 'object' field")
|
||
|
||
// 不再是 DenyResponseBody {code, blockedDetails} 结构;
|
||
// 现在是 wrapper.MarshalStr 产物 —— 裸 JSON 字符串字面量
|
||
require.False(t, gjson.GetBytes(local.Data, "code").Exists(),
|
||
"A-R1F2: body should now be a JSON string literal, not a {code, blockedDetails} object")
|
||
require.False(t, gjson.GetBytes(local.Data, "blockedDetails").Exists(),
|
||
"A-R1F2: body should now be a JSON string literal, not a {code, blockedDetails} object")
|
||
|
||
// 实际形态:wrapper.MarshalStr 产物。该 wrapper 在 Higress 的实现里返回
|
||
// 已剥除外层双引号的字符串(见 C-R1F9 备注),所以 body 是原始 deny 文本
|
||
// 字节(不可 json.Unmarshal 回字符串)。
|
||
require.Equal(t, cfg.DefaultDenyMessage, string(local.Data),
|
||
"body should equal raw deny message (wrapper.MarshalStr strips outer quotes — C-R1F9)")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestResponseStreamingDenyGuardrail drives HandleTextGenerationStreamingResponseBody
|
||
// (lvwang/common/text/openai.go) — the only response-side structured stream
|
||
// writer — and asserts the injected SSE carries:
|
||
// - first chunk: human-readable content, no x_higress_guardrail
|
||
// - end chunk: x_higress_guardrail as a JSON object with code/denyMessage/blockedDetails
|
||
// - terminator: "data: [DONE]"
|
||
func TestResponseStreamingDenyGuardrail(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("response streaming deny injects guardrail only in last frame", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(basicConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// Skip request-phase check by using a non-deny request body.
|
||
body := `{"messages": [{"role": "user", "content": "hello"}]}`
|
||
host.CallOnHttpRequestBody([]byte(body))
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(`{"Code": 200, "Message": "Success", "RequestId": "req-stream-resp-pass", "Data": {"RiskLevel": "none"}}`))
|
||
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "text/event-stream"},
|
||
})
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
// Single chunk + end_of_stream=true triggers the security check.
|
||
chunk := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"bad response\"}}]}\n\n")
|
||
host.CallOnHttpStreamingResponseBody(chunk, true)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-resp-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
injected := host.GetResponseBody()
|
||
require.NotEmpty(t, injected, "expected InjectEncodedDataToFilterChain to deliver SSE deny payload")
|
||
|
||
injectedStr := string(injected)
|
||
require.True(t, strings.HasSuffix(strings.TrimSpace(injectedStr), "data: [DONE]"),
|
||
"expected SSE stream to end with [DONE], got: %s", injectedStr)
|
||
|
||
// Strip "data:" prefixes and split into events.
|
||
events := strings.Split(strings.TrimSpace(injectedStr), "\n\n")
|
||
require.GreaterOrEqual(t, len(events), 2, "expected at least first chunk + end chunk")
|
||
|
||
// First event: content present, no x_higress_guardrail.
|
||
firstPayload := strings.TrimPrefix(events[0], "data:")
|
||
firstPayload = strings.TrimSpace(firstPayload)
|
||
require.Equal(t, cfg.DefaultDenyMessage, gjson.Get(firstPayload, "choices.0.delta.content").String())
|
||
require.False(t, gjson.Get(firstPayload, "choices.0.x_higress_guardrail").Exists(),
|
||
"first chunk must NOT carry x_higress_guardrail")
|
||
require.False(t, gjson.Get(firstPayload, "choices.0.x_higress").Exists(),
|
||
"first chunk must NOT carry old x_higress")
|
||
|
||
// Second event: x_higress_guardrail as JSON object on choices[0].
|
||
secondPayload := strings.TrimPrefix(events[1], "data:")
|
||
secondPayload = strings.TrimSpace(secondPayload)
|
||
guardrail := gjson.Get(secondPayload, "choices.0.x_higress_guardrail")
|
||
require.True(t, guardrail.Exists(), "end chunk must carry x_higress_guardrail")
|
||
require.True(t, guardrail.IsObject(), "x_higress_guardrail must be a JSON object, not a string")
|
||
require.Equal(t, cfg.DefaultDenyMessage, guardrail.Get("denyMessage").String())
|
||
require.True(t, guardrail.Get("blockedDetails").Exists())
|
||
require.False(t, gjson.Get(secondPayload, "choices.0.x_higress").Exists(),
|
||
"end chunk must not carry old x_higress")
|
||
// Streaming deny terminator carries `stop` for wire-level compatibility;
|
||
// moderation signal lives in choices[0].x_higress_guardrail.
|
||
require.Equal(t, "stop", gjson.Get(secondPayload, "choices.0.finish_reason").String())
|
||
|
||
// Design contract: x_higress_guardrail lives ONLY nested under choices[0].
|
||
require.False(t, gjson.Get(secondPayload, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root of the end chunk")
|
||
require.False(t, gjson.Get(firstPayload, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root of the first chunk")
|
||
require.False(t, gjson.Get(secondPayload, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root of the end chunk")
|
||
require.False(t, gjson.Get(firstPayload, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root of the first chunk")
|
||
})
|
||
})
|
||
}
|
||
|
||
// A-R2-16(a): multi_modal_guard 图像审核 deny 通道未被前文 guardrail 测试覆盖,
|
||
// 而 R1-F6 修复(BuildOpenAIFallbackDenyResponseBody err 不再静默)依赖
|
||
// callbackForImage 与 singleCallForImage 调用 BuildOpenAIDenyResponseBody。
|
||
// 本测试发送纯 image_url 请求体直接命中 singleCallForImage 路径,断言图像
|
||
// 审核 deny 同样把 x_higress_guardrail 嵌入 choices[0],与文本 deny 形态对称。
|
||
//
|
||
// 文件:lvwang/multi_modal_guard/text/openai.go:299-369(callbackForImage / OpenAI 包装)
|
||
func TestImageDenyGuardrailShape(t *testing.T) {
|
||
imageCheckConfig := func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkRequestImage": true,
|
||
"action": "MultiModalGuard",
|
||
"apiType": "text_generation",
|
||
"contentModerationLevelBar": "high",
|
||
"promptAttackLevelBar": "high",
|
||
"sensitiveDataLevelBar": "S3",
|
||
"timeout": 2000,
|
||
"bufferLimit": 1000,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("multi_modal_guard image deny embeds guardrail on choices[0]", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(imageCheckConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// 纯 image_url 内容:content=="",parseContent 跳过文本审核直接走
|
||
// singleCallForImage,让 callbackForImage 渲染 deny。
|
||
body := `{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"https://example.com/bad.jpg"}}]}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for image deny")
|
||
|
||
// x_higress_guardrail 必须作为对象嵌在 choices[0] 内
|
||
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(),
|
||
"image deny should embed x_higress_guardrail as JSON object on choices[0]")
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(),
|
||
"image deny should not emit old x_higress on choices[0]")
|
||
|
||
var outer openAIBodyWithGuardrail
|
||
require.NoError(t, json.Unmarshal(local.Data, &outer))
|
||
require.Len(t, outer.Choices, 1)
|
||
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content,
|
||
"content carries human-readable deny text")
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
|
||
require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code,
|
||
"x_higress_guardrail.code = gateway HTTP deny status (A-R1F4 contract)")
|
||
require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails)
|
||
|
||
// 不可泄漏到 body 根
|
||
require.Nil(t, outer.Guardrail, "x_higress_guardrail must not leak to body root")
|
||
require.Nil(t, outer.XHigress, "old x_higress must not leak to body root")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root")
|
||
})
|
||
})
|
||
}
|
||
|
||
// A-R2-16(b): text_moderation_plus 请求阶段 deny 通道未被前文 guardrail 测试覆盖。
|
||
// 已有 TestTextModerationPlusResponseDeny 覆盖响应阶段,本测试补齐请求阶段对称
|
||
// 用例,确保 OpenAI 协议下 structured deny body 把 x_higress_guardrail 放在 choices[0]。
|
||
//
|
||
// 文件:lvwang/text_moderation_plus/text/openai.go:56-92(请求阶段 deny 渲染)
|
||
func TestTextModerationPlusRequestDenyGuardrailShape(t *testing.T) {
|
||
tmpRequestConfig := func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"action": "TextModerationPlus",
|
||
"contentModerationLevelBar": "high",
|
||
"timeout": 2000,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("text_moderation_plus request deny embeds guardrail on choices[0]", func(t *testing.T) {
|
||
host, status := test.NewTestHost(withStructuredFormat(tmpRequestConfig))
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
|
||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for text_moderation_plus request deny")
|
||
|
||
require.True(t, gjson.GetBytes(local.Data, "choices.0.x_higress_guardrail").IsObject(),
|
||
"text_moderation_plus request deny should embed x_higress_guardrail as JSON object on choices[0]")
|
||
require.False(t, gjson.GetBytes(local.Data, "choices.0.x_higress").Exists(),
|
||
"text_moderation_plus request deny should not emit old x_higress")
|
||
|
||
var outer openAIBodyWithGuardrail
|
||
require.NoError(t, json.Unmarshal(local.Data, &outer))
|
||
require.Len(t, outer.Choices, 1)
|
||
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Message.Content)
|
||
require.Equal(t, cfg.DefaultDenyMessage, outer.Choices[0].Guardrail.DenyMessage)
|
||
require.Equal(t, int(cfg.DefaultDenyCode), outer.Choices[0].Guardrail.Code,
|
||
"x_higress_guardrail.code = gateway HTTP deny status (A-R1F4 contract)")
|
||
require.NotNil(t, outer.Choices[0].Guardrail.BlockedDetails)
|
||
|
||
require.Nil(t, outer.Guardrail, "x_higress_guardrail must not leak to body root")
|
||
require.Nil(t, outer.XHigress, "old x_higress must not leak to body root")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress_guardrail").Exists(),
|
||
"x_higress_guardrail must not leak to body root")
|
||
require.False(t, gjson.GetBytes(local.Data, "x_higress").Exists(),
|
||
"old x_higress must not leak to body root")
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestEmbeddingConfig(t *testing.T) {
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
t.Run("embedding config with responseErrorContentJsonPath", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, "embedding", securityConfig.ApiType)
|
||
require.Equal(t, "input", securityConfig.RequestContentJsonPath)
|
||
require.Equal(t, "data", securityConfig.ResponseContentJsonPath)
|
||
require.Equal(t, "error.message", securityConfig.ResponseErrorContentJsonPath)
|
||
require.Equal(t, true, securityConfig.CheckRequest)
|
||
require.Equal(t, true, securityConfig.CheckResponse)
|
||
})
|
||
|
||
t.Run("embedding config without responseErrorContentJsonPath", func(t *testing.T) {
|
||
// Test backward compatibility when responseErrorContentJsonPath is not provided
|
||
configWithoutErrorPath := func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"serviceName": "security-service",
|
||
"servicePort": 8080,
|
||
"serviceHost": "security.example.com",
|
||
"accessKey": "test-ak",
|
||
"secretKey": "test-sk",
|
||
"checkRequest": true,
|
||
"checkResponse": true,
|
||
"action": "MultiModalGuard",
|
||
"apiType": "embedding",
|
||
"requestContentJsonPath": "input",
|
||
"responseContentJsonPath": "data",
|
||
"contentModerationLevelBar": "high",
|
||
"timeout": 2000,
|
||
})
|
||
return data
|
||
}()
|
||
host, status := test.NewTestHost(configWithoutErrorPath)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
|
||
securityConfig := config.(*cfg.AISecurityConfig)
|
||
require.Equal(t, "embedding", securityConfig.ApiType)
|
||
require.Equal(t, "", securityConfig.ResponseErrorContentJsonPath)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestEmbeddingRequest(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("embedding request with string input pass", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"input": "Hello, how are you?", "model": "text-embedding-ada-002"}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-pass", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
t.Run("embedding request with string array input pass", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"input": ["Hello", "World"], "model": "text-embedding-ada-002"}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-array-pass", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
t.Run("embedding request with token ID array skip", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// Token ID array input - should skip detection
|
||
body := `{"input": [1234, 5678, 9012], "model": "text-embedding-ada-002"}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
// Should continue without checking (unsupported input type)
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
|
||
t.Run("embedding request deny with embedding error format", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
body := `{"input": "bad content", "model": "text-embedding-ada-002"}`
|
||
action := host.CallOnHttpRequestBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for Embedding request deny")
|
||
// Verify the response uses Embedding error format
|
||
var errorResp map[string]interface{}
|
||
require.NoError(t, json.Unmarshal(local.Data, &errorResp))
|
||
require.Contains(t, errorResp, "error")
|
||
errorObj := errorResp["error"].(map[string]interface{})
|
||
require.Contains(t, errorObj, "message")
|
||
require.Contains(t, errorObj, "type")
|
||
require.Contains(t, errorObj, "code")
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestEmbeddingResponse(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("embedding response with error message", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// Response with error.message field
|
||
body := `{"error": {"message": "Rate limit exceeded", "type": "rate_limit_error"}, "data": []}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-resp-error", "Data": {"RiskLevel": "low"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
action = host.GetHttpStreamAction()
|
||
require.Equal(t, types.ActionContinue, action)
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
t.Run("embedding response vector only skip", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// Standard embedding response with only vectors - no text content
|
||
body := `{
|
||
"object": "list",
|
||
"data": [
|
||
{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
|
||
{"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]}
|
||
],
|
||
"model": "text-embedding-ada-002",
|
||
"usage": {"prompt_tokens": 10, "total_tokens": 10}
|
||
}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
// Should skip since no text content
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
|
||
t.Run("embedding response base64 vector skip", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// Embedding response with base64 encoding_format - embedding is a string, not an array
|
||
body := `{
|
||
"object": "list",
|
||
"data": [
|
||
{"object": "embedding", "index": 0, "embedding": "AGC3PAAAtzzAQLc8gEC3PEBAtzy"}
|
||
],
|
||
"model": "text-embedding-ada-002",
|
||
"usage": {"prompt_tokens": 10, "total_tokens": 10}
|
||
}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
// Should skip since base64 embedding strings are not user content
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
|
||
t.Run("embedding response deny with embedding error format", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// Response with text content in error.message
|
||
body := `{"error": {"message": "bad response content"}, "data": []}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-resp-deny", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for Embedding response deny")
|
||
// Verify the response uses Embedding error format
|
||
var errorResp map[string]interface{}
|
||
require.NoError(t, json.Unmarshal(local.Data, &errorResp))
|
||
require.Contains(t, errorResp, "error")
|
||
errorObj := errorResp["error"].(map[string]interface{})
|
||
require.Contains(t, errorObj, "message")
|
||
require.Contains(t, errorObj, "type")
|
||
require.Contains(t, errorObj, "code")
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestEmbeddingStreamingIgnored(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("embedding streaming response ignores responseStreamContentJsonPath", func(t *testing.T) {
|
||
host, status := test.NewTestHost(embeddingConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// Simulate streaming response headers
|
||
host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
|
||
// Even if streaming content path is set, embedding should process non-streaming
|
||
body := `{
|
||
"object": "list",
|
||
"data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2]}],
|
||
"model": "text-embedding-ada-002"
|
||
}`
|
||
action := host.CallOnHttpResponseBody([]byte(body))
|
||
// Should continue since there's no text content
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestEmbeddingNon200Response(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("embedding API should check response body for non-200 status", func(t *testing.T) {
|
||
// Embedding API with responseErrorContentJsonPath should check error.message
|
||
// even when status code is not 200
|
||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/embeddings"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// Non-200 response (e.g., 400 Bad Request)
|
||
// For embedding API, response body should be buffered for later processing
|
||
// HandleEmbeddingResponseHeaders returns HeaderStopIteration (ActionPause)
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "400"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
// HeaderStopIteration = ActionPause indicates body will be buffered and processed
|
||
require.Equal(t, types.HeaderStopIteration, action)
|
||
|
||
// Response body with error.message should be checked
|
||
body := `{"error": {"message": "Invalid input: sensitive content detected", "type": "invalid_request_error"}}`
|
||
action = host.CallOnHttpResponseBody([]byte(body))
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// Simulate security service response with high risk
|
||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-non200", "Data": {"RiskLevel": "high"}}`
|
||
host.CallOnHttpCall([][2]string{
|
||
{":status", "200"},
|
||
{"content-type", "application/json"},
|
||
}, []byte(securityResponse))
|
||
|
||
// Verify deny response was sent (Embedding error format)
|
||
local := host.GetLocalResponse()
|
||
require.NotNil(t, local, "expected SendHttpResponse for Embedding deny")
|
||
var errorResp map[string]interface{}
|
||
require.NoError(t, json.Unmarshal(local.Data, &errorResp))
|
||
require.Contains(t, errorResp, "error")
|
||
errorObj := errorResp["error"].(map[string]interface{})
|
||
require.Contains(t, errorObj, "message")
|
||
})
|
||
|
||
t.Run("non-embedding API should skip response body for non-200 status", func(t *testing.T) {
|
||
// Non-embedding API should maintain existing behavior: skip response body
|
||
// for non-200 responses
|
||
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
})
|
||
|
||
// Non-200 response
|
||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||
{":status", "500"},
|
||
{"content-type", "application/json"},
|
||
})
|
||
// For non-embedding API, should skip response body check
|
||
require.Equal(t, types.ActionContinue, action)
|
||
})
|
||
})
|
||
}
|