mirror of
https://github.com/alibaba/higress.git
synced 2026-06-03 17:47:25 +08:00
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
455 lines
17 KiB
Go
455 lines
17 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"testing"
|
|
|
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/higress-group/wasm-go/pkg/iface"
|
|
"github.com/higress-group/wasm-go/pkg/test"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type aiLogSnapshot struct {
|
|
SafecheckRequests []cfg.GuardrailSubmissionEvent `json:"safecheck_requests"`
|
|
SafecheckRequestIDs []string `json:"safecheck_request_ids"`
|
|
SafecheckRequestID string `json:"safecheck_request_id"`
|
|
SafecheckStatus string `json:"safecheck_status"`
|
|
}
|
|
|
|
func readAILogSnapshot(t *testing.T, host test.TestHost) (aiLogSnapshot, string) {
|
|
t.Helper()
|
|
raw, err := host.GetProperty([]string{wrapper.AILogKey})
|
|
require.NoError(t, err)
|
|
decoded := wrapper.UnmarshalStr(`"` + string(raw) + `"`)
|
|
require.NotEmpty(t, decoded)
|
|
|
|
var snapshot aiLogSnapshot
|
|
require.NoError(t, json.Unmarshal([]byte(decoded), &snapshot))
|
|
return snapshot, decoded
|
|
}
|
|
|
|
func requireAILogArraySchema(t *testing.T, raw string) {
|
|
t.Helper()
|
|
require.True(t, gjson.Get(raw, cfg.SafecheckRequestsKey).IsArray(), "safecheck_requests must be a JSON array")
|
|
require.True(t, gjson.Get(raw, cfg.SafecheckRequestIDsKey).IsArray(), "safecheck_request_ids must be a JSON array")
|
|
}
|
|
|
|
func requireSafecheckEvent(t *testing.T, event cfg.GuardrailSubmissionEvent, phase, modality, result, requestID string) {
|
|
t.Helper()
|
|
require.Equal(t, phase, event.Phase)
|
|
require.Equal(t, modality, event.Modality)
|
|
require.Equal(t, result, event.Result)
|
|
require.Equal(t, requestID, event.RequestID)
|
|
}
|
|
|
|
func TestGuardrailAILogRequestAndResponseEventSchema(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
t.Run("request pass emits one structured text event", func(t *testing.T) {
|
|
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
})
|
|
|
|
body := `{"messages": [{"role": "user", "content": "Hello"}]}`
|
|
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
|
|
|
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-structured-pass", "Data": {"RiskLevel": "low"}}`
|
|
host.CallOnHttpCall([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, []byte(securityResponse))
|
|
|
|
snapshot, raw := readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 1)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultPass, "req-structured-pass")
|
|
require.Equal(t, []string{"req-structured-pass"}, snapshot.SafecheckRequestIDs)
|
|
require.Equal(t, "req-structured-pass", snapshot.SafecheckRequestID)
|
|
require.Equal(t, "request pass", snapshot.SafecheckStatus)
|
|
})
|
|
|
|
t.Run("response deny emits one structured text event", func(t *testing.T) {
|
|
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
})
|
|
host.CallOnHttpResponseHeaders([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}`
|
|
require.Equal(t, types.ActionPause, host.CallOnHttpResponseBody([]byte(body)))
|
|
|
|
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-structured-deny", "Data": {"RiskLevel": "high"}}`
|
|
host.CallOnHttpCall([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, []byte(securityResponse))
|
|
|
|
snapshot, raw := readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 1)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText, cfg.GuardrailResultDeny, "req-structured-deny")
|
|
require.Equal(t, []string{"req-structured-deny"}, snapshot.SafecheckRequestIDs)
|
|
require.Equal(t, "req-structured-deny", snapshot.SafecheckRequestID)
|
|
require.Equal(t, "response deny", snapshot.SafecheckStatus)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestGuardrailAILogStreamingPassFlushesBeforeEOS(t *testing.T) {
|
|
streamingFlushConfig := func() json.RawMessage {
|
|
data, _ := json.Marshal(map[string]interface{}{
|
|
"serviceName": "security-service",
|
|
"servicePort": 8080,
|
|
"serviceHost": "security.example.com",
|
|
"accessKey": "test-ak",
|
|
"secretKey": "test-sk",
|
|
"checkRequest": false,
|
|
"checkResponse": true,
|
|
"action": "MultiModalGuard",
|
|
"apiType": "text_generation",
|
|
"contentModerationLevelBar": "high",
|
|
"promptAttackLevelBar": "high",
|
|
"sensitiveDataLevelBar": "S3",
|
|
"timeout": 2000,
|
|
"bufferLimit": 1,
|
|
})
|
|
return data
|
|
}()
|
|
|
|
test.RunTest(t, func(t *testing.T) {
|
|
host, status := test.NewTestHost(streamingFlushConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
})
|
|
host.CallOnHttpResponseHeaders([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "text/event-stream"},
|
|
})
|
|
|
|
chunk := []byte("data:{\"id\":\"chatcmpl-1\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n")
|
|
host.CallOnHttpStreamingResponseBody(chunk, false)
|
|
|
|
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-pass", "Data": {"RiskLevel": "low"}}`
|
|
host.CallOnHttpCall([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, []byte(securityResponse))
|
|
|
|
snapshot, raw := readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 1)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText, cfg.GuardrailResultPass, "req-stream-pass")
|
|
require.False(t, gjson.Get(raw, "safecheck_status").Exists(), "event-level flush should not wait for a terminal safecheck_status")
|
|
})
|
|
}
|
|
|
|
func TestGuardrailAILogErrorFlushAndOrderingForImageSubmissions(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
runCase := func(t *testing.T, firstHeaders [][2]string, firstResponse, firstRequestID string) {
|
|
host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/images/generations"},
|
|
{":method", "POST"},
|
|
})
|
|
|
|
body := `{"input": {"images": ["https://example.com/a.png", "https://example.com/b.png"]}}`
|
|
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
|
|
|
host.CallOnHttpCall(firstHeaders, []byte(firstResponse))
|
|
snapshot, raw := readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 1)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage, cfg.GuardrailResultError, firstRequestID)
|
|
require.Equal(t, []string{firstRequestID}, snapshot.SafecheckRequestIDs)
|
|
|
|
secondResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-image-pass", "Data": {"RiskLevel": "low"}}`
|
|
host.CallOnHttpCall([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, []byte(secondResponse))
|
|
|
|
snapshot, raw = readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 2)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[1], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityImage, cfg.GuardrailResultPass, "req-image-pass")
|
|
require.Equal(t, []string{firstRequestID, "req-image-pass"}, snapshot.SafecheckRequestIDs)
|
|
require.Equal(t, "req-image-pass", snapshot.SafecheckRequestID)
|
|
}
|
|
|
|
t.Run("non-200 HTTP response flushes error before next image submission", func(t *testing.T) {
|
|
runCase(t, [][2]string{
|
|
{":status", "502"},
|
|
{"content-type", "application/json"},
|
|
}, `{"RequestId": "req-http-error"}`, "req-http-error")
|
|
})
|
|
|
|
t.Run("business failure flushes error before next image submission", func(t *testing.T) {
|
|
runCase(t, [][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, `{"Code": 500, "Message": "Failed", "RequestId": "req-business-error"}`, "req-business-error")
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestGuardrailAILogMalformedRequestIDsAreIgnored(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
response string
|
|
expectedResult string
|
|
}{
|
|
{
|
|
name: "missing",
|
|
response: `{"Code": 200, "Message": "Success", "Data": {"RiskLevel": "low"}}`,
|
|
expectedResult: cfg.GuardrailResultPass,
|
|
},
|
|
{
|
|
name: "empty",
|
|
response: `{"Code": 200, "Message": "Success", "RequestId": "", "Data": {"RiskLevel": "low"}}`,
|
|
expectedResult: cfg.GuardrailResultPass,
|
|
},
|
|
{
|
|
name: "whitespace",
|
|
response: `{"Code": 200, "Message": "Success", "RequestId": " ", "Data": {"RiskLevel": "low"}}`,
|
|
expectedResult: cfg.GuardrailResultPass,
|
|
},
|
|
{
|
|
name: "non-string",
|
|
response: `{"Code": 200, "Message": "Success", "RequestId": 123, "Data": {"RiskLevel": "low"}}`,
|
|
expectedResult: cfg.GuardrailResultError,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
})
|
|
|
|
body := `{"messages": [{"role": "user", "content": "Hello"}]}`
|
|
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
|
|
|
host.CallOnHttpCall([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, []byte(tc.response))
|
|
|
|
snapshot, raw := readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 1)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, tc.expectedResult, "")
|
|
require.Empty(t, snapshot.SafecheckRequestIDs)
|
|
require.False(t, gjson.Get(raw, cfg.SafecheckRequestIDKey).Exists())
|
|
require.False(t, gjson.Get(raw, cfg.SafecheckRequestsKey+".0.requestId").Exists())
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGuardrailAILogMaskFallbackRecordsDeny(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
host, status := test.NewTestHost(maskConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
})
|
|
|
|
body := `{"messages": [{"role": "user", "content": "敏感内容"}]}`
|
|
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
|
|
|
securityResponse := `{
|
|
"Code": 200, "Message": "Success", "RequestId": "req-mask-fallback",
|
|
"Data": {
|
|
"RiskLevel": "none",
|
|
"Detail": [{
|
|
"Suggestion": "mask", "Type": "sensitiveData", "Level": "S3",
|
|
"Result": [{"Label": "phone", "Confidence": 99.0,
|
|
"Ext": {"Desensitization": ""}}]
|
|
}]
|
|
}
|
|
}`
|
|
host.CallOnHttpCall([][2]string{
|
|
{":status", "200"},
|
|
{"content-type", "application/json"},
|
|
}, []byte(securityResponse))
|
|
|
|
snapshot, raw := readAILogSnapshot(t, host)
|
|
requireAILogArraySchema(t, raw)
|
|
require.Len(t, snapshot.SafecheckRequests, 1)
|
|
requireSafecheckEvent(t, snapshot.SafecheckRequests[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultDeny, "req-mask-fallback")
|
|
require.Equal(t, []string{"req-mask-fallback"}, snapshot.SafecheckRequestIDs)
|
|
})
|
|
}
|
|
|
|
func TestGuardrailAILogDispatchFailureEmitsErrorEvent(t *testing.T) {
|
|
ctx := newStubHTTPContext()
|
|
eventIndex := cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText)
|
|
cfg.CompleteGuardrailSubmissionEventWithRequestID(ctx, eventIndex, "", cfg.GuardrailResultError)
|
|
cfg.WriteGuardrailLog(ctx)
|
|
|
|
events, ok := ctx.GetUserAttribute(cfg.SafecheckRequestsKey).([]cfg.GuardrailSubmissionEvent)
|
|
require.True(t, ok)
|
|
require.Len(t, events, 1)
|
|
requireSafecheckEvent(t, events[0], cfg.GuardrailPhaseRequest, cfg.GuardrailModalityText, cfg.GuardrailResultError, "")
|
|
|
|
requestIDs, ok := ctx.GetUserAttribute(cfg.SafecheckRequestIDsKey).([]string)
|
|
require.True(t, ok)
|
|
require.Empty(t, requestIDs)
|
|
require.Nil(t, ctx.GetUserAttribute(cfg.SafecheckRequestIDKey))
|
|
require.Equal(t, []string{wrapper.AILogKey}, ctx.writes)
|
|
}
|
|
|
|
// stubHTTPContext is an in-memory HTTP context test double. It stores context
// values and user attributes in plain maps, queues pushed buffers in a slice,
// and records the key of every attribute-log write so tests can inspect them.
type stubHTTPContext struct {
	// userContext backs SetContext and the Get*Context accessors.
	userContext map[string]interface{}
	// userAttribute backs the user-attribute getters and setters.
	userAttribute map[string]interface{}
	// bufferQueue is the FIFO used by PushBuffer and PopBuffer.
	bufferQueue [][]byte
	// writes lists, in call order, the keys passed to
	// WriteUserAttributeToLogWithKey.
	writes []string
	// routeCallError is the value RouteCall returns.
	routeCallError error
}
|
|
|
|
func newStubHTTPContext() *stubHTTPContext {
|
|
return &stubHTTPContext{
|
|
userContext: map[string]interface{}{},
|
|
userAttribute: map[string]interface{}{},
|
|
}
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) Scheme() string { return "" }
|
|
func (ctx *stubHTTPContext) Host() string { return "" }
|
|
func (ctx *stubHTTPContext) Path() string { return "" }
|
|
func (ctx *stubHTTPContext) Method() string { return "" }
|
|
|
|
func (ctx *stubHTTPContext) SetContext(key string, value interface{}) {
|
|
ctx.userContext[key] = value
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetContext(key string) interface{} {
|
|
return ctx.userContext[key]
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetBoolContext(key string, defaultValue bool) bool {
|
|
if value, ok := ctx.userContext[key].(bool); ok {
|
|
return value
|
|
}
|
|
return defaultValue
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetStringContext(key, defaultValue string) string {
|
|
if value, ok := ctx.userContext[key].(string); ok {
|
|
return value
|
|
}
|
|
return defaultValue
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetByteSliceContext(key string, defaultValue []byte) []byte {
|
|
if value, ok := ctx.userContext[key].([]byte); ok {
|
|
return value
|
|
}
|
|
return defaultValue
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetUserAttribute(key string) interface{} {
|
|
return ctx.userAttribute[key]
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) SetUserAttribute(key string, value interface{}) {
|
|
ctx.userAttribute[key] = value
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) SetUserAttributeMap(kvmap map[string]interface{}) {
|
|
ctx.userAttribute = kvmap
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetUserAttributeMap() map[string]interface{} {
|
|
return ctx.userAttribute
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) WriteUserAttributeToLog() error {
|
|
return ctx.WriteUserAttributeToLogWithKey(wrapper.CustomLogKey)
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) WriteUserAttributeToLogWithKey(key string) error {
|
|
ctx.writes = append(ctx.writes, key)
|
|
return nil
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) WriteUserAttributeToTrace() error { return nil }
|
|
func (ctx *stubHTTPContext) DontReadRequestBody() {}
|
|
func (ctx *stubHTTPContext) DontReadResponseBody() {}
|
|
func (ctx *stubHTTPContext) BufferRequestBody() {}
|
|
func (ctx *stubHTTPContext) BufferResponseBody() {}
|
|
func (ctx *stubHTTPContext) NeedPauseStreamingResponse() {}
|
|
|
|
func (ctx *stubHTTPContext) PushBuffer(buffer []byte) {
|
|
ctx.bufferQueue = append(ctx.bufferQueue, buffer)
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) PopBuffer() []byte {
|
|
if len(ctx.bufferQueue) == 0 {
|
|
return nil
|
|
}
|
|
buffer := ctx.bufferQueue[0]
|
|
ctx.bufferQueue = ctx.bufferQueue[1:]
|
|
return buffer
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) BufferQueueSize() int { return len(ctx.bufferQueue) }
|
|
func (ctx *stubHTTPContext) DisableReroute() {}
|
|
func (ctx *stubHTTPContext) SetRequestBodyBufferLimit(uint32) {
|
|
}
|
|
func (ctx *stubHTTPContext) SetResponseBodyBufferLimit(uint32) {
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) RouteCall(string, string, [][2]string, []byte, iface.RouteResponseCallback) error {
|
|
return ctx.routeCallError
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) GetExecutionPhase() iface.HTTPExecutionPhase {
|
|
return iface.DecodeData
|
|
}
|
|
|
|
func (ctx *stubHTTPContext) HasRequestBody() bool { return true }
|
|
func (ctx *stubHTTPContext) HasResponseBody() bool { return true }
|
|
func (ctx *stubHTTPContext) IsWebsocket() bool { return false }
|
|
func (ctx *stubHTTPContext) IsBinaryRequestBody() bool { return false }
|
|
func (ctx *stubHTTPContext) IsBinaryResponseBody() bool {
|
|
return false
|
|
}
|