feat(ai-security-guard): add fallback JSON paths for response content extraction (#3738)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
This commit is contained in:
JianweiWang
2026-04-28 14:58:59 +08:00
committed by GitHub
parent 1d33067372
commit 5173b4b2b8
6 changed files with 896 additions and 32 deletions

View File

@@ -29,6 +29,8 @@ description: 阿里云内容安全检测
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath |
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath |
| `responseContentFallbackJsonPaths` | array | optional | [`choices.0.message.content`, `content.#(type=="text")#.text`] | 当 `responseContentJsonPath` 提取为空时,按顺序尝试这些兜底路径;与主路径相同的项会自动跳过;显式配置为空数组 `[]` 可禁用兜底 |
| `responseStreamContentFallbackJsonPaths` | array | optional | [`choices.0.delta.content`, `delta.text`] | 当 `responseStreamContentJsonPath` 提取为空时,按顺序尝试这些流式兜底路径;与主路径相同的项会自动跳过;显式配置为空数组 `[]` 可禁用兜底 |
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 |
| `protocol` | string | optional | openai | 协议格式非openai协议填`original` |
@@ -211,6 +213,34 @@ denyMessage: "很抱歉,我无法回答您的问题"
protocol: original
```
### 配置响应内容兜底提取路径
当主路径提取不到内容时,可按优先级顺序配置兜底路径,兼容多种返回协议:
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkResponse: true
responseContentJsonPath: "choices.0.message.content"
responseStreamContentJsonPath: "choices.0.delta.content"
responseContentFallbackJsonPaths:
- "output.text"
- 'content.#(type=="text")#.text'
responseStreamContentFallbackJsonPaths:
- "payload.delta"
- "delta.text"
```
如需严格模式(主路径未命中即跳过,不走兜底),可显式关闭兜底:
```yaml
responseContentFallbackJsonPaths: []
responseStreamContentFallbackJsonPaths: []
```
## 可观测
### Metric
ai-security-guard 插件提供了以下监控指标:

View File

@@ -29,6 +29,8 @@ Plugin Priority: `300`
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body |
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body |
| `responseContentFallbackJsonPaths` | array | optional | [`choices.0.message.content`, `content.#(type=="text")#.text`] | Fallback paths tried in order when `responseContentJsonPath` extracts empty content; entries equal to the primary path are skipped automatically; set to `[]` to disable fallback explicitly |
| `responseStreamContentFallbackJsonPaths` | array | optional | [`choices.0.delta.content`, `delta.text`] | Streaming fallback paths tried in order when `responseStreamContentJsonPath` extracts empty content; entries equal to the primary path are skipped automatically; set to `[]` to disable fallback explicitly |
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal |
| `protocol` | string | optional | openai | protocol format, `openai` or `original` |
@@ -129,6 +131,34 @@ checkRequest: true
checkResponse: true
```
### Configure response fallback extraction paths
When primary extraction paths are empty, you can configure ordered fallback paths to support multiple response formats:
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkResponse: true
responseContentJsonPath: "choices.0.message.content"
responseStreamContentJsonPath: "choices.0.delta.content"
responseContentFallbackJsonPaths:
- "output.text"
- 'content.#(type=="text")#.text'
responseStreamContentFallbackJsonPaths:
- "payload.delta"
- "delta.text"
```
To enforce strict mode (no fallback), configure both fields as empty arrays:
```yaml
responseContentFallbackJsonPaths: []
responseStreamContentFallbackJsonPaths: []
```
## Observability
### Metric
ai-security-guard plugin provides following metrics:

View File

@@ -67,6 +67,26 @@ const (
DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
)
var (
// Keep these defaults aligned with previous hardcoded fallback extraction behavior.
defaultResponseFallbackJsonPaths = []string{
"choices.0.message.content",
`content.#(type=="text")#.text`,
}
defaultStreamingResponseFallbackJsonPaths = []string{
"choices.0.delta.content",
"delta.text",
}
)
func DefaultResponseFallbackJsonPaths() []string {
return append([]string(nil), defaultResponseFallbackJsonPaths...)
}
func DefaultStreamingResponseFallbackJsonPaths() []string {
return append([]string(nil), defaultStreamingResponseFallbackJsonPaths...)
}
// api types
const (
@@ -143,38 +163,40 @@ func (m *Matcher) match(consumer string) bool {
}
type AISecurityConfig struct {
Client wrapper.HttpClient
Host string
AK string
SK string
Token string
Action string
CheckRequest bool
CheckRequestImage bool
RequestCheckService string
RequestImageCheckService string
RequestContentJsonPath string
CheckResponse bool
ResponseCheckService string
ResponseImageCheckService string
ResponseContentJsonPath string
ResponseStreamContentJsonPath string
DenyCode int64
DenyMessage string
ProtocolOriginal bool
RiskLevelBar string
ContentModerationLevelBar string
PromptAttackLevelBar string
SensitiveDataLevelBar string
MaliciousUrlLevelBar string
ModelHallucinationLevelBar string
CustomLabelLevelBar string
Timeout uint32
BufferLimit int
Metrics map[string]proxywasm.MetricCounter
ConsumerRequestCheckService []map[string]interface{}
ConsumerResponseCheckService []map[string]interface{}
ConsumerRiskLevel []map[string]interface{}
Client wrapper.HttpClient
Host string
AK string
SK string
Token string
Action string
CheckRequest bool
CheckRequestImage bool
RequestCheckService string
RequestImageCheckService string
RequestContentJsonPath string
CheckResponse bool
ResponseCheckService string
ResponseImageCheckService string
ResponseContentJsonPath string
ResponseStreamContentJsonPath string
ResponseContentFallbackJsonPaths []string
ResponseStreamContentFallbackJsonPaths []string
DenyCode int64
DenyMessage string
ProtocolOriginal bool
RiskLevelBar string
ContentModerationLevelBar string
PromptAttackLevelBar string
SensitiveDataLevelBar string
MaliciousUrlLevelBar string
ModelHallucinationLevelBar string
CustomLabelLevelBar string
Timeout uint32
BufferLimit int
Metrics map[string]proxywasm.MetricCounter
ConsumerRequestCheckService []map[string]interface{}
ConsumerResponseCheckService []map[string]interface{}
ConsumerRiskLevel []map[string]interface{}
// text_generation, image_generation, etc.
ApiType string
// openai, qwen, comfyui, etc.
@@ -287,6 +309,16 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error {
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
config.ResponseStreamContentJsonPath = obj.String()
}
if paths, exists, err := parseOptionalStringArrayConfig(json, "responseContentFallbackJsonPaths"); err != nil {
return err
} else if exists {
config.ResponseContentFallbackJsonPaths = paths
}
if paths, exists, err := parseOptionalStringArrayConfig(json, "responseStreamContentFallbackJsonPaths"); err != nil {
return err
} else if exists {
config.ResponseStreamContentFallbackJsonPaths = paths
}
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
config.ContentModerationLevelBar = obj.String()
if LevelToInt(config.ContentModerationLevelBar) <= 0 {
@@ -448,6 +480,29 @@ func parseDimensionAction(json gjson.Result, fieldName string) (string, error) {
return "", nil
}
func parseOptionalStringArrayConfig(json gjson.Result, fieldName string) ([]string, bool, error) {
obj := json.Get(fieldName)
if !obj.Exists() {
return nil, false, nil
}
if !obj.IsArray() {
return nil, true, fmt.Errorf("invalid %s, value must be an array of non-empty strings", fieldName)
}
items := obj.Array()
paths := make([]string, 0, len(items))
for _, item := range items {
if item.Type != gjson.String {
return nil, true, fmt.Errorf("invalid %s, value must be an array of non-empty strings", fieldName)
}
path := strings.TrimSpace(item.String())
if path == "" {
return nil, true, fmt.Errorf("invalid %s, value must be an array of non-empty strings", fieldName)
}
paths = append(paths, path)
}
return paths, true, nil
}
func (config *AISecurityConfig) SetDefaultValues() {
switch config.Action {
case TextModerationPlus:
@@ -463,6 +518,8 @@ func (config *AISecurityConfig) SetDefaultValues() {
config.RequestContentJsonPath = DefaultRequestJsonPath
config.ResponseContentJsonPath = DefaultResponseJsonPath
config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
config.ResponseContentFallbackJsonPaths = DefaultResponseFallbackJsonPaths()
config.ResponseStreamContentFallbackJsonPaths = DefaultStreamingResponseFallbackJsonPaths()
config.ContentModerationLevelBar = MaxRisk
config.PromptAttackLevelBar = MaxRisk
config.SensitiveDataLevelBar = S4Sensitive

View File

@@ -18,11 +18,18 @@ import (
"github.com/tidwall/gjson"
)
const (
responseFallbackPathsCtxKey = "response_fallback_paths"
responseStreamFallbackPathsCtxKey = "response_stream_fallback_paths"
)
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
ctx.SetContext("end_of_stream_received", false)
ctx.SetContext("during_call", false)
ctx.SetContext("risk_detected", false)
ctx.SetContext(responseFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths))
ctx.SetContext(responseStreamFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths))
sessionID, _ := utils.GenerateHexID(20)
ctx.SetContext("sessionID", sessionID)
if strings.Contains(contentType, "text/event-stream") {
@@ -36,6 +43,7 @@ func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISe
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
consumer, _ := ctx.GetContext("consumer").(string)
streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
var sessionID string
if ctx.GetContext("sessionID") == nil {
sessionID, _ = utils.GenerateHexID(20)
@@ -101,6 +109,9 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
front := ctx.PopBuffer()
bufferQueue = append(bufferQueue, front)
msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
if len(msg) == 0 {
msg = autoExtractStreamingResponseContent(front, streamFallbackPaths)
}
buffer += msg
if len([]rune(buffer)) >= config.BufferLimit {
break
@@ -162,6 +173,8 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
responseFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseFallbackPathsCtxKey, config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths)
streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
log.Debugf("checking response body...")
startTime := time.Now().UnixMilli()
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
@@ -169,8 +182,14 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
var content string
if isStreamingResponse {
content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
if len(content) == 0 {
content = autoExtractStreamingResponseFromSSE(body, streamFallbackPaths)
}
} else {
content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
if len(content) == 0 {
content = autoExtractResponseContent(body, responseFallbackPaths)
}
}
log.Debugf("Raw response content is: %s", content)
if len(content) == 0 {
@@ -255,3 +274,148 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
singleCall()
return types.ActionPause
}
// autoExtractResponseContent tries configured fallback paths to extract text content.
func autoExtractResponseContent(body []byte, fallbackPaths []string) string {
if len(fallbackPaths) == 0 {
return ""
}
parsed := gjson.ParseBytes(body)
return extractTextByPaths(parsed, fallbackPaths)
}
// autoExtractStreamingResponseContent tries configured fallback paths to extract text content.
// It handles both bare JSON and SSE "data:" payloads, including multi-line data events.
func autoExtractStreamingResponseContent(chunk []byte, fallbackPaths []string) string {
if len(fallbackPaths) == 0 {
return ""
}
payload := bytes.TrimSpace(chunk)
if len(payload) == 0 {
return ""
}
if !isJSONPayload(payload) {
payload = extractSSEDataPayload(payload)
if len(payload) == 0 {
return ""
}
}
if !json.Valid(payload) {
return ""
}
parsed := gjson.ParseBytes(payload)
return extractTextByPaths(parsed, fallbackPaths)
}
func isJSONPayload(payload []byte) bool {
return len(payload) > 0 && (payload[0] == '{' || payload[0] == '[')
}
// extractSSEDataPayload concatenates all "data:" lines in one SSE event.
// SSE specifies multi-line data fields should be joined with '\n'.
func extractSSEDataPayload(chunk []byte) []byte {
lines := bytes.Split(chunk, []byte("\n"))
dataLines := make([][]byte, 0, len(lines))
for _, line := range lines {
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
return nil
}
dataLines = append(dataLines, data)
}
if len(dataLines) == 0 {
return nil
}
return bytes.TrimSpace(bytes.Join(dataLines, []byte("\n")))
}
func buildEffectiveFallbackPaths(primaryPath string, fallbackPaths []string) []string {
primaryPath = strings.TrimSpace(primaryPath)
if len(fallbackPaths) == 0 {
return []string{}
}
deduped := make([]string, 0, len(fallbackPaths))
seen := make(map[string]struct{}, len(fallbackPaths))
for _, path := range fallbackPaths {
path = strings.TrimSpace(path)
if len(path) == 0 || path == primaryPath {
continue
}
if _, ok := seen[path]; ok {
continue
}
seen[path] = struct{}{}
deduped = append(deduped, path)
}
if len(deduped) == 0 {
return []string{}
}
return deduped
}
type fallbackPathContext interface {
GetContext(key string) interface{}
SetContext(key string, value interface{})
}
func getEffectiveFallbackPathsFromContext(ctx fallbackPathContext, ctxKey string, primaryPath string, fallbackPaths []string) []string {
if cached, ok := ctx.GetContext(ctxKey).([]string); ok {
return cached
}
effective := buildEffectiveFallbackPaths(primaryPath, fallbackPaths)
ctx.SetContext(ctxKey, effective)
return effective
}
func extractTextByPaths(parsed gjson.Result, paths []string) string {
for _, path := range paths {
path = strings.TrimSpace(path)
if len(path) == 0 {
continue
}
result := parsed.Get(path)
if !result.Exists() {
continue
}
if text := extractTextFromResult(result); len(text) > 0 {
log.Debugf("response fallback path matched: %s", path)
return text
}
}
return ""
}
func extractTextFromResult(result gjson.Result) string {
if result.IsArray() {
var parts []string
for _, item := range result.Array() {
if s := item.String(); len(s) > 0 {
parts = append(parts, s)
}
}
return strings.Join(parts, "")
}
return result.String()
}
// autoExtractStreamingResponseFromSSE tries configured fallback paths on a full SSE body.
func autoExtractStreamingResponseFromSSE(data []byte, fallbackPaths []string) string {
if len(fallbackPaths) == 0 {
return ""
}
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
var parts []string
for _, chunk := range chunks {
if s := autoExtractStreamingResponseContent(chunk, fallbackPaths); len(s) > 0 {
parts = append(parts, s)
}
}
return strings.Join(parts, "")
}

View File

@@ -0,0 +1,377 @@
package text
import (
"os"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
wasmlog "github.com/higress-group/wasm-go/pkg/log"
"github.com/tidwall/gjson"
)
type noopPluginLog struct{}
func (noopPluginLog) Trace(string) {}
func (noopPluginLog) Tracef(string, ...interface{}) {}
func (noopPluginLog) Debug(string) {}
func (noopPluginLog) Debugf(string, ...interface{}) {}
func (noopPluginLog) Info(string) {}
func (noopPluginLog) Infof(string, ...interface{}) {}
func (noopPluginLog) Warn(string) {}
func (noopPluginLog) Warnf(string, ...interface{}) {}
func (noopPluginLog) Error(string) {}
func (noopPluginLog) Errorf(string, ...interface{}) {}
func (noopPluginLog) Critical(string) {}
func (noopPluginLog) Criticalf(string, ...interface{}) {}
func (noopPluginLog) ResetID(string) {}
func TestMain(m *testing.M) {
wasmlog.SetPluginLog(noopPluginLog{})
os.Exit(m.Run())
}
type fallbackPathMockContext struct {
values map[string]interface{}
}
func (m *fallbackPathMockContext) GetContext(key string) interface{} {
return m.values[key]
}
func (m *fallbackPathMockContext) SetContext(key string, value interface{}) {
if m.values == nil {
m.values = make(map[string]interface{})
}
m.values[key] = value
}
func TestAutoExtractResponseContent(t *testing.T) {
tests := []struct {
name string
body string
fallbackPaths []string
want string
}{
{
name: "OpenAI format",
body: `{"choices":[{"message":{"content":"hello world"}}]}`,
want: "hello world",
},
{
name: "Claude format simple",
body: `{"content":[{"type":"text","text":"hello claude"}]}`,
want: "hello claude",
},
{
name: "Claude format with thinking block first",
body: `{"content":[{"type":"thinking","thinking":"let me think..."},{"type":"text","text":"hello after thinking"}]}`,
want: "hello after thinking",
},
{
name: "Claude format multiple text blocks concatenated",
body: `{"content":[{"type":"thinking","thinking":"..."},{"type":"text","text":"first"},{"type":"text","text":" second"}]}`,
want: "first second",
},
{
name: "Claude format first text block empty, second non-empty",
body: `{"content":[{"type":"text","text":""},{"type":"text","text":"actual content"}]}`,
want: "actual content",
},
{
name: "empty body",
body: `{}`,
want: "",
},
{
name: "no matching format",
body: `{"result":"some other format"}`,
want: "",
},
{
name: "custom fallback path",
body: `{"output":{"text":"custom fallback text"}}`,
fallbackPaths: []string{"output.text"},
want: "custom fallback text",
},
{
name: "fallback path list with empty item",
body: `{"output":{"text":"custom fallback text"}}`,
fallbackPaths: []string{" ", "output.text"},
want: "custom fallback text",
},
{
name: "fallback disabled explicitly",
body: `{"choices":[{"message":{"content":"hello world"}}]}`,
fallbackPaths: []string{},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fallbackPaths := tt.fallbackPaths
if fallbackPaths == nil {
fallbackPaths = cfg.DefaultResponseFallbackJsonPaths()
}
got := autoExtractResponseContent([]byte(tt.body), fallbackPaths)
if got != tt.want {
t.Errorf("autoExtractResponseContent() = %q, want %q", got, tt.want)
}
})
}
}
func TestAutoExtractStreamingResponseContent(t *testing.T) {
tests := []struct {
name string
chunk string
fallbackPaths []string
want string
}{
{
name: "OpenAI streaming format",
chunk: `{"choices":[{"delta":{"content":"hello"}}]}`,
want: "hello",
},
{
name: "Claude streaming format",
chunk: `{"type":"content_block_delta","delta":{"type":"text_delta","text":"hello claude"}}`,
want: "hello claude",
},
{
name: "Claude thinking delta - no text extracted",
chunk: `{"type":"content_block_delta","delta":{"type":"thinking_delta","thinking":"let me think"}}`,
want: "",
},
{
name: "empty chunk",
chunk: `{}`,
want: "",
},
{
name: "OpenAI with data: prefix",
chunk: "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}",
want: "hello",
},
{
name: "Claude with event: and data: prefix",
chunk: "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}",
want: "hello",
},
{
name: "OpenAI with multi-line data fields",
chunk: `event: message
data: {
data: "choices": [{"delta": {"content": "hello multiline"}}]
data: }`,
want: "hello multiline",
},
{
name: "data: [DONE] returns empty",
chunk: "data: [DONE]",
want: "",
},
{
name: "custom streaming fallback path",
chunk: `{"payload":{"delta":"custom stream"}}`,
fallbackPaths: []string{"payload.delta"},
want: "custom stream",
},
{
name: "streaming fallback disabled explicitly",
chunk: `{"choices":[{"delta":{"content":"hello"}}]}`,
fallbackPaths: []string{},
want: "",
},
{
name: "empty chunk payload",
chunk: "",
want: "",
},
{
name: "invalid json payload after data extraction",
chunk: "data: invalid-json",
want: "",
},
{
name: "streaming payload with empty data line",
chunk: "event: message\ndata:\ndata: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}",
want: "hello",
},
{
name: "streaming payload without data lines",
chunk: "event: ping",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fallbackPaths := tt.fallbackPaths
if fallbackPaths == nil {
fallbackPaths = cfg.DefaultStreamingResponseFallbackJsonPaths()
}
got := autoExtractStreamingResponseContent([]byte(tt.chunk), fallbackPaths)
if got != tt.want {
t.Errorf("autoExtractStreamingResponseContent() = %q, want %q", got, tt.want)
}
})
}
}
// Test that configured path takes priority over fallback.
func TestConfiguredPathPriority(t *testing.T) {
// Body has both OpenAI and a custom field
body := `{"choices":[{"message":{"content":"openai content"}}],"custom":"custom content"}`
// Custom path extracts successfully - should NOT fall back
content := extractWithFallback([]byte(body), "custom", cfg.DefaultResponseFallbackJsonPaths())
if content != "custom content" {
t.Errorf("expected custom path to take priority, got %q", content)
}
// Custom path misses - should fall back to OpenAI
content = extractWithFallback([]byte(body), "nonexistent.path", cfg.DefaultResponseFallbackJsonPaths())
if content != "openai content" {
t.Errorf("expected fallback to OpenAI, got %q", content)
}
// Fallback disabled - should stay empty when configured path misses.
content = extractWithFallback([]byte(body), "nonexistent.path", []string{})
if content != "" {
t.Errorf("expected empty result when fallback disabled, got %q", content)
}
}
// extractWithFallback mirrors the real extraction logic in HandleTextGenerationResponseBody.
func extractWithFallback(body []byte, jsonPath string, fallbackPaths []string) string {
content := gjsonGetString(body, jsonPath)
if len(content) == 0 {
content = autoExtractResponseContent(body, fallbackPaths)
}
return content
}
func gjsonGetString(body []byte, path string) string {
return gjson.GetBytes(body, path).String()
}
// Test SSE body fallback for buffered streaming branch.
func TestAutoExtractStreamingResponseFromSSE(t *testing.T) {
tests := []struct {
name string
body string
fallbackPaths []string
want string
}{
{
name: "OpenAI SSE body",
body: "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\ndata: [DONE]\n\n",
want: "hello world",
},
{
name: "Claude SSE body with thinking and text deltas",
body: "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"hmm\"}}\n\n" +
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n" +
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" claude\"}}\n\n" +
"data: [DONE]\n\n",
want: "hello claude",
},
{
name: "empty SSE body",
body: "data: [DONE]\n\n",
want: "",
},
{
name: "OpenAI multi-line data events in full SSE body",
body: `event: message
data: {
data: "choices": [{"delta": {"content": "hello"}}]
data: }
event: message
data: {
data: "choices": [{"delta": {"content": " world"}}]
data: }
data: [DONE]
`,
want: "hello world",
},
{
name: "custom fallback paths in full SSE body",
body: "data: {\"payload\":{\"delta\":\"hello\"}}\n\ndata: {\"payload\":{\"delta\":\" world\"}}\n\n",
fallbackPaths: []string{
"payload.delta",
},
want: "hello world",
},
{
name: "streaming fallback disabled for full SSE body",
body: "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n",
fallbackPaths: []string{},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fallbackPaths := tt.fallbackPaths
if fallbackPaths == nil {
fallbackPaths = cfg.DefaultStreamingResponseFallbackJsonPaths()
}
got := autoExtractStreamingResponseFromSSE([]byte(tt.body), fallbackPaths)
if got != tt.want {
t.Errorf("autoExtractStreamingResponseFromSSE() = %q, want %q", got, tt.want)
}
})
}
}
func TestBuildEffectiveFallbackPaths(t *testing.T) {
if paths := buildEffectiveFallbackPaths("choices.0.message.content", nil); len(paths) != 0 {
t.Fatalf("expected empty paths when fallback list is nil, got %#v", paths)
}
emptyByFilter := buildEffectiveFallbackPaths("choices.0.message.content", []string{
"choices.0.message.content",
" ",
"",
})
if len(emptyByFilter) != 0 {
t.Fatalf("expected empty paths after filtering duplicates/empty values, got %#v", emptyByFilter)
}
paths := buildEffectiveFallbackPaths("choices.0.message.content", []string{
"choices.0.message.content",
"delta.text",
"delta.text",
"",
" ",
"output.text",
})
if len(paths) != 2 {
t.Fatalf("expected 2 paths after filtering, got %d", len(paths))
}
if paths[0] != "delta.text" || paths[1] != "output.text" {
t.Fatalf("unexpected filtered fallback paths: %#v", paths)
}
}
func TestGetEffectiveFallbackPathsFromContext(t *testing.T) {
ctx := &fallbackPathMockContext{values: make(map[string]interface{})}
got := getEffectiveFallbackPathsFromContext(ctx, "fallback_key", "choices.0.message.content", []string{
"choices.0.message.content",
"output.text",
})
if len(got) != 1 || got[0] != "output.text" {
t.Fatalf("unexpected effective paths from uncached context: %#v", got)
}
if cached, ok := ctx.values["fallback_key"].([]string); !ok || len(cached) != 1 || cached[0] != "output.text" {
t.Fatalf("expected effective paths to be cached in context, got %#v", ctx.values["fallback_key"])
}
ctx.values["fallback_key"] = []string{"cached.path"}
got = getEffectiveFallbackPathsFromContext(ctx, "fallback_key", "nonexistent", []string{"another.path"})
if len(got) != 1 || got[0] != "cached.path" {
t.Fatalf("expected cached paths to take precedence, got %#v", got)
}
}

View File

@@ -333,6 +333,8 @@ func TestParseConfig(t *testing.T) {
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.Timeout)
require.Equal(t, 1000, securityConfig.BufferLimit)
require.Equal(t, cfg.DefaultResponseFallbackJsonPaths(), securityConfig.ResponseContentFallbackJsonPaths)
require.Equal(t, cfg.DefaultStreamingResponseFallbackJsonPaths(), securityConfig.ResponseStreamContentFallbackJsonPaths)
})
// 测试仅检查请求的配置
@@ -390,6 +392,116 @@ func TestParseConfig(t *testing.T) {
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
})
t.Run("custom response fallback paths config", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"responseContentFallbackJsonPaths": []string{"output.text", "choices.0.message.content"},
"responseStreamContentFallbackJsonPaths": []string{"payload.delta", "delta.text"},
})
require.NoError(t, err)
host, status := test.NewTestHost(configJSON)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, []string{"output.text", "choices.0.message.content"}, securityConfig.ResponseContentFallbackJsonPaths)
require.Equal(t, []string{"payload.delta", "delta.text"}, securityConfig.ResponseStreamContentFallbackJsonPaths)
})
t.Run("empty response fallback paths disable fallback", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"responseContentFallbackJsonPaths": []string{},
"responseStreamContentFallbackJsonPaths": []string{},
})
require.NoError(t, err)
host, status := test.NewTestHost(configJSON)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
securityConfig := config.(*cfg.AISecurityConfig)
require.Len(t, securityConfig.ResponseContentFallbackJsonPaths, 0)
require.Len(t, securityConfig.ResponseStreamContentFallbackJsonPaths, 0)
})
t.Run("invalid response fallback paths type", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"responseContentFallbackJsonPaths": "choices.0.message.content",
})
require.NoError(t, err)
host, status := test.NewTestHost(configJSON)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
t.Run("invalid response fallback paths item", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"responseStreamContentFallbackJsonPaths": []interface{}{"delta.text", ""},
})
require.NoError(t, err)
host, status := test.NewTestHost(configJSON)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
t.Run("invalid response fallback paths non-string item", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"responseStreamContentFallbackJsonPaths": []interface{}{"delta.text", 123},
})
require.NoError(t, err)
host, status := test.NewTestHost(configJSON)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
t.Run("invalid contentModerationLevelBar value", func(t *testing.T) {
configJSON, err := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"contentModerationLevelBar": "invalid",
})
require.NoError(t, err)
host, status := test.NewTestHost(configJSON)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
})
}
@@ -632,6 +744,100 @@ func TestOnHttpResponseBody(t *testing.T) {
})
}
func TestResponseFallbackExtractionCoverage(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
base := map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkResponse": true,
"action": "MultiModalGuard",
"apiType": "text_generation",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
"bufferLimit": 1000,
}
withOverrides := func(overrides map[string]interface{}) json.RawMessage {
cfgMap := make(map[string]interface{}, len(base)+len(overrides))
for k, v := range base {
cfgMap[k] = v
}
for k, v := range overrides {
cfgMap[k] = v
}
data, err := json.Marshal(cfgMap)
require.NoError(t, err)
return data
}
t.Run("streaming response chunk uses configured fallback path", func(t *testing.T) {
host, status := test.NewTestHost(withOverrides(map[string]interface{}{
"responseStreamContentJsonPath": "nonexistent.path",
"responseStreamContentFallbackJsonPaths": []string{"choices.0.delta.content"},
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
require.Equal(t, types.ActionContinue, action)
chunk := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"hello fallback\"}}]}\n\n")
host.CallOnHttpStreamingResponseBody(chunk, true)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-fallback", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
host.CompleteHttp()
})
t.Run("buffered response body uses streaming fallback extraction", func(t *testing.T) {
host, status := test.NewTestHost(withOverrides(map[string]interface{}{
"responseStreamContentJsonPath": "nonexistent.path",
"responseStreamContentFallbackJsonPaths": []string{"choices.0.delta.content"},
}))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
body := "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\ndata: [DONE]\n\n"
host.CallOnHttpResponseBody([]byte(body))
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-buffered-stream-fallback", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
host.CompleteHttp()
})
})
}
func TestMCP(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test MCP Response Body Check - Pass