diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 885fbdc4d..5d547e9f3 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -29,6 +29,8 @@ description: 阿里云内容安全检测 | `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath | | `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath | | `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath | +| `responseContentFallbackJsonPaths` | array | optional | [`choices.0.message.content`, `content.#(type=="text")#.text`] | 当 `responseContentJsonPath` 提取为空时,按顺序尝试这些兜底路径;与主路径相同的项会自动跳过;显式配置为空数组 `[]` 可禁用兜底 | +| `responseStreamContentFallbackJsonPaths` | array | optional | [`choices.0.delta.content`, `delta.text`] | 当 `responseStreamContentJsonPath` 提取为空时,按顺序尝试这些流式兜底路径;与主路径相同的项会自动跳过;显式配置为空数组 `[]` 可禁用兜底 | | `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | @@ -211,6 +213,34 @@ denyMessage: "很抱歉,我无法回答您的问题" protocol: original ``` +### 配置响应内容兜底提取路径 + +当主路径提取不到内容时,可按优先级顺序配置兜底路径,兼容多种返回协议: + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: "green-cip.cn-shanghai.aliyuncs.com" +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkResponse: true +responseContentJsonPath: "choices.0.message.content" +responseStreamContentJsonPath: "choices.0.delta.content" +responseContentFallbackJsonPaths: + - "output.text" + - 'content.#(type=="text")#.text' +responseStreamContentFallbackJsonPaths: + - "payload.delta" + - "delta.text" +``` + +如需严格模式(主路径未命中即跳过,不走兜底),可显式关闭兜底: + +```yaml +responseContentFallbackJsonPaths: [] +responseStreamContentFallbackJsonPaths: [] +``` + ## 可观测 ### Metric ai-security-guard 插件提供了以下监控指标: diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md index 61849763c..08fa672b8 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -29,6 +29,8 @@ Plugin Priority: `300` | `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body | | `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body | | `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body | +| `responseContentFallbackJsonPaths` | array | optional | [`choices.0.message.content`, `content.#(type=="text")#.text`] | Fallback paths tried in order when `responseContentJsonPath` extracts empty content; entries equal to the primary path are skipped automatically; set to `[]` to disable fallback explicitly | +| `responseStreamContentFallbackJsonPaths` | array | optional | [`choices.0.delta.content`, `delta.text`] | Streaming fallback paths tried in order when `responseStreamContentJsonPath` extracts empty content; entries equal to the primary path are skipped automatically; set to `[]` to disable fallback explicitly | | `denyCode` | int | optional | 200 | Response status code when the specified content is illegal | | `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal | | `protocol` | string | optional | openai | protocol format, `openai` or `original` | @@ -129,6 +131,34 @@ checkRequest: true checkResponse: true ``` +### Configure response fallback extraction paths + +When primary extraction paths are empty, you can configure ordered fallback paths to support multiple response formats: + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: green-cip.cn-shanghai.aliyuncs.com +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkResponse: true +responseContentJsonPath: "choices.0.message.content" +responseStreamContentJsonPath: "choices.0.delta.content" +responseContentFallbackJsonPaths: + - "output.text" + - 'content.#(type=="text")#.text' +responseStreamContentFallbackJsonPaths: + - "payload.delta" + - "delta.text" +``` + +To enforce strict mode (no fallback), configure both fields as empty arrays: + +```yaml +responseContentFallbackJsonPaths: [] +responseStreamContentFallbackJsonPaths: [] +``` + ## Observability ### Metric ai-security-guard plugin provides following metrics: diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/config.go b/plugins/wasm-go/extensions/ai-security-guard/config/config.go index ed3ea595d..e3c9fb252 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/config/config.go +++ b/plugins/wasm-go/extensions/ai-security-guard/config/config.go @@ -67,6 +67,26 @@ const ( DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation" ) +var ( + // Keep these defaults aligned with previous hardcoded fallback extraction behavior. + defaultResponseFallbackJsonPaths = []string{ + "choices.0.message.content", + `content.#(type=="text")#.text`, + } + defaultStreamingResponseFallbackJsonPaths = []string{ + "choices.0.delta.content", + "delta.text", + } +) + +func DefaultResponseFallbackJsonPaths() []string { + return append([]string(nil), defaultResponseFallbackJsonPaths...) +} + +func DefaultStreamingResponseFallbackJsonPaths() []string { + return append([]string(nil), defaultStreamingResponseFallbackJsonPaths...) +} + // api types const ( @@ -143,38 +163,40 @@ func (m *Matcher) match(consumer string) bool { } type AISecurityConfig struct { - Client wrapper.HttpClient - Host string - AK string - SK string - Token string - Action string - CheckRequest bool - CheckRequestImage bool - RequestCheckService string - RequestImageCheckService string - RequestContentJsonPath string - CheckResponse bool - ResponseCheckService string - ResponseImageCheckService string - ResponseContentJsonPath string - ResponseStreamContentJsonPath string - DenyCode int64 - DenyMessage string - ProtocolOriginal bool - RiskLevelBar string - ContentModerationLevelBar string - PromptAttackLevelBar string - SensitiveDataLevelBar string - MaliciousUrlLevelBar string - ModelHallucinationLevelBar string - CustomLabelLevelBar string - Timeout uint32 - BufferLimit int - Metrics map[string]proxywasm.MetricCounter - ConsumerRequestCheckService []map[string]interface{} - ConsumerResponseCheckService []map[string]interface{} - ConsumerRiskLevel []map[string]interface{} + Client wrapper.HttpClient + Host string + AK string + SK string + Token string + Action string + CheckRequest bool + CheckRequestImage bool + RequestCheckService string + RequestImageCheckService string + RequestContentJsonPath string + CheckResponse bool + ResponseCheckService string + ResponseImageCheckService string + ResponseContentJsonPath string + ResponseStreamContentJsonPath string + ResponseContentFallbackJsonPaths []string + ResponseStreamContentFallbackJsonPaths []string + DenyCode int64 + DenyMessage string + ProtocolOriginal bool + RiskLevelBar string + ContentModerationLevelBar string + PromptAttackLevelBar string + SensitiveDataLevelBar string + MaliciousUrlLevelBar string + ModelHallucinationLevelBar string + CustomLabelLevelBar string + Timeout uint32 + BufferLimit int + Metrics map[string]proxywasm.MetricCounter + ConsumerRequestCheckService []map[string]interface{} + ConsumerResponseCheckService []map[string]interface{} + ConsumerRiskLevel []map[string]interface{} // text_generation, image_generation, etc. ApiType string // openai, qwen, comfyui, etc. @@ -287,6 +309,16 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() { config.ResponseStreamContentJsonPath = obj.String() } + if paths, exists, err := parseOptionalStringArrayConfig(json, "responseContentFallbackJsonPaths"); err != nil { + return err + } else if exists { + config.ResponseContentFallbackJsonPaths = paths + } + if paths, exists, err := parseOptionalStringArrayConfig(json, "responseStreamContentFallbackJsonPaths"); err != nil { + return err + } else if exists { + config.ResponseStreamContentFallbackJsonPaths = paths + } if obj := json.Get("contentModerationLevelBar"); obj.Exists() { config.ContentModerationLevelBar = obj.String() if LevelToInt(config.ContentModerationLevelBar) <= 0 { @@ -448,6 +480,29 @@ func parseDimensionAction(json gjson.Result, fieldName string) (string, error) { return "", nil } +func parseOptionalStringArrayConfig(json gjson.Result, fieldName string) ([]string, bool, error) { + obj := json.Get(fieldName) + if !obj.Exists() { + return nil, false, nil + } + if !obj.IsArray() { + return nil, true, fmt.Errorf("invalid %s, value must be an array of non-empty strings", fieldName) + } + items := obj.Array() + paths := make([]string, 0, len(items)) + for _, item := range items { + if item.Type != gjson.String { + return nil, true, fmt.Errorf("invalid %s, value must be an array of non-empty strings", fieldName) + } + path := strings.TrimSpace(item.String()) + if path == "" { + return nil, true, fmt.Errorf("invalid %s, value must be an array of non-empty strings", fieldName) + } + paths = append(paths, path) + } + return paths, true, nil +} + func (config *AISecurityConfig) SetDefaultValues() { switch config.Action { case TextModerationPlus: @@ -463,6 +518,8 @@ func (config *AISecurityConfig) SetDefaultValues() { config.RequestContentJsonPath = DefaultRequestJsonPath config.ResponseContentJsonPath = DefaultResponseJsonPath config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath + config.ResponseContentFallbackJsonPaths = DefaultResponseFallbackJsonPaths() + config.ResponseStreamContentFallbackJsonPaths = DefaultStreamingResponseFallbackJsonPaths() config.ContentModerationLevelBar = MaxRisk config.PromptAttackLevelBar = MaxRisk config.SensitiveDataLevelBar = S4Sensitive diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go index 00faca425..e736d55ac 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go @@ -18,11 +18,18 @@ import ( "github.com/tidwall/gjson" ) +const ( + responseFallbackPathsCtxKey = "response_fallback_paths" + responseStreamFallbackPathsCtxKey = "response_stream_fallback_paths" +) + func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") ctx.SetContext("end_of_stream_received", false) ctx.SetContext("during_call", false) ctx.SetContext("risk_detected", false) + ctx.SetContext(responseFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths)) + ctx.SetContext(responseStreamFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)) sessionID, _ := utils.GenerateHexID(20) ctx.SetContext("sessionID", sessionID) if strings.Contains(contentType, "text/event-stream") { @@ -36,6 +43,7 @@ func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISe func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte { consumer, _ := ctx.GetContext("consumer").(string) + streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths) var sessionID string if ctx.GetContext("sessionID") == nil { sessionID, _ = utils.GenerateHexID(20) @@ -101,6 +109,9 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c front := ctx.PopBuffer() bufferQueue = append(bufferQueue, front) msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String() + if len(msg) == 0 { + msg = autoExtractStreamingResponseContent(front, streamFallbackPaths) + } buffer += msg if len([]rune(buffer)) >= config.BufferLimit { break @@ -162,6 +173,8 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { consumer, _ := ctx.GetContext("consumer").(string) + responseFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseFallbackPathsCtxKey, config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths) + streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths) log.Debugf("checking response body...") startTime := time.Now().UnixMilli() contentType, _ := proxywasm.GetHttpResponseHeader("content-type") @@ -169,8 +182,14 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu var content string if isStreamingResponse { content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath) + if len(content) == 0 { + content = autoExtractStreamingResponseFromSSE(body, streamFallbackPaths) + } } else { content = gjson.GetBytes(body, config.ResponseContentJsonPath).String() + if len(content) == 0 { + content = autoExtractResponseContent(body, responseFallbackPaths) + } } log.Debugf("Raw response content is: %s", content) if len(content) == 0 { @@ -255,3 +274,148 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu singleCall() return types.ActionPause } + +// autoExtractResponseContent tries configured fallback paths to extract text content. +func autoExtractResponseContent(body []byte, fallbackPaths []string) string { + if len(fallbackPaths) == 0 { + return "" + } + parsed := gjson.ParseBytes(body) + return extractTextByPaths(parsed, fallbackPaths) +} + +// autoExtractStreamingResponseContent tries configured fallback paths to extract text content. +// It handles both bare JSON and SSE "data:" payloads, including multi-line data events. +func autoExtractStreamingResponseContent(chunk []byte, fallbackPaths []string) string { + if len(fallbackPaths) == 0 { + return "" + } + payload := bytes.TrimSpace(chunk) + if len(payload) == 0 { + return "" + } + if !isJSONPayload(payload) { + payload = extractSSEDataPayload(payload) + if len(payload) == 0 { + return "" + } + } + if !json.Valid(payload) { + return "" + } + parsed := gjson.ParseBytes(payload) + return extractTextByPaths(parsed, fallbackPaths) +} + +func isJSONPayload(payload []byte) bool { + return len(payload) > 0 && (payload[0] == '{' || payload[0] == '[') +} + +// extractSSEDataPayload concatenates all "data:" lines in one SSE event. +// SSE specifies multi-line data fields should be joined with '\n'. +func extractSSEDataPayload(chunk []byte) []byte { + lines := bytes.Split(chunk, []byte("\n")) + dataLines := make([][]byte, 0, len(lines)) + for _, line := range lines { + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:"))) + if len(data) == 0 { + continue + } + if bytes.Equal(data, []byte("[DONE]")) { + return nil + } + dataLines = append(dataLines, data) + } + if len(dataLines) == 0 { + return nil + } + return bytes.TrimSpace(bytes.Join(dataLines, []byte("\n"))) +} + +func buildEffectiveFallbackPaths(primaryPath string, fallbackPaths []string) []string { + primaryPath = strings.TrimSpace(primaryPath) + if len(fallbackPaths) == 0 { + return []string{} + } + deduped := make([]string, 0, len(fallbackPaths)) + seen := make(map[string]struct{}, len(fallbackPaths)) + for _, path := range fallbackPaths { + path = strings.TrimSpace(path) + if len(path) == 0 || path == primaryPath { + continue + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + deduped = append(deduped, path) + } + if len(deduped) == 0 { + return []string{} + } + return deduped +} + +type fallbackPathContext interface { + GetContext(key string) interface{} + SetContext(key string, value interface{}) +} + +func getEffectiveFallbackPathsFromContext(ctx fallbackPathContext, ctxKey string, primaryPath string, fallbackPaths []string) []string { + if cached, ok := ctx.GetContext(ctxKey).([]string); ok { + return cached + } + effective := buildEffectiveFallbackPaths(primaryPath, fallbackPaths) + ctx.SetContext(ctxKey, effective) + return effective +} + +func extractTextByPaths(parsed gjson.Result, paths []string) string { + for _, path := range paths { + path = strings.TrimSpace(path) + if len(path) == 0 { + continue + } + result := parsed.Get(path) + if !result.Exists() { + continue + } + if text := extractTextFromResult(result); len(text) > 0 { + log.Debugf("response fallback path matched: %s", path) + return text + } + } + return "" +} + +func extractTextFromResult(result gjson.Result) string { + if result.IsArray() { + var parts []string + for _, item := range result.Array() { + if s := item.String(); len(s) > 0 { + parts = append(parts, s) + } + } + return strings.Join(parts, "") + } + return result.String() +} + +// autoExtractStreamingResponseFromSSE tries configured fallback paths on a full SSE body. +func autoExtractStreamingResponseFromSSE(data []byte, fallbackPaths []string) string { + if len(fallbackPaths) == 0 { + return "" + } + chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) + var parts []string + for _, chunk := range chunks { + if s := autoExtractStreamingResponseContent(chunk, fallbackPaths); len(s) > 0 { + parts = append(parts, s) + } + } + return strings.Join(parts, "") +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai_test.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai_test.go new file mode 100644 index 000000000..3badc9e54 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai_test.go @@ -0,0 +1,377 @@ +package text + +import ( + "os" + "testing" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + wasmlog "github.com/higress-group/wasm-go/pkg/log" + "github.com/tidwall/gjson" +) + +type noopPluginLog struct{} + +func (noopPluginLog) Trace(string) {} +func (noopPluginLog) Tracef(string, ...interface{}) {} +func (noopPluginLog) Debug(string) {} +func (noopPluginLog) Debugf(string, ...interface{}) {} +func (noopPluginLog) Info(string) {} +func (noopPluginLog) Infof(string, ...interface{}) {} +func (noopPluginLog) Warn(string) {} +func (noopPluginLog) Warnf(string, ...interface{}) {} +func (noopPluginLog) Error(string) {} +func (noopPluginLog) Errorf(string, ...interface{}) {} +func (noopPluginLog) Critical(string) {} +func (noopPluginLog) Criticalf(string, ...interface{}) {} +func (noopPluginLog) ResetID(string) {} + +func TestMain(m *testing.M) { + wasmlog.SetPluginLog(noopPluginLog{}) + os.Exit(m.Run()) +} + +type fallbackPathMockContext struct { + values map[string]interface{} +} + +func (m *fallbackPathMockContext) GetContext(key string) interface{} { + return m.values[key] +} + +func (m *fallbackPathMockContext) SetContext(key string, value interface{}) { + if m.values == nil { + m.values = make(map[string]interface{}) + } + m.values[key] = value +} + +func TestAutoExtractResponseContent(t *testing.T) { + tests := []struct { + name string + body string + fallbackPaths []string + want string + }{ + { + name: "OpenAI format", + body: `{"choices":[{"message":{"content":"hello world"}}]}`, + want: "hello world", + }, + { + name: "Claude format simple", + body: `{"content":[{"type":"text","text":"hello claude"}]}`, + want: "hello claude", + }, + { + name: "Claude format with thinking block first", + body: `{"content":[{"type":"thinking","thinking":"let me think..."},{"type":"text","text":"hello after thinking"}]}`, + want: "hello after thinking", + }, + { + name: "Claude format multiple text blocks concatenated", + body: `{"content":[{"type":"thinking","thinking":"..."},{"type":"text","text":"first"},{"type":"text","text":" second"}]}`, + want: "first second", + }, + { + name: "Claude format first text block empty, second non-empty", + body: `{"content":[{"type":"text","text":""},{"type":"text","text":"actual content"}]}`, + want: "actual content", + }, + { + name: "empty body", + body: `{}`, + want: "", + }, + { + name: "no matching format", + body: `{"result":"some other format"}`, + want: "", + }, + { + name: "custom fallback path", + body: `{"output":{"text":"custom fallback text"}}`, + fallbackPaths: []string{"output.text"}, + want: "custom fallback text", + }, + { + name: "fallback path list with empty item", + body: `{"output":{"text":"custom fallback text"}}`, + fallbackPaths: []string{" ", "output.text"}, + want: "custom fallback text", + }, + { + name: "fallback disabled explicitly", + body: `{"choices":[{"message":{"content":"hello world"}}]}`, + fallbackPaths: []string{}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fallbackPaths := tt.fallbackPaths + if fallbackPaths == nil { + fallbackPaths = cfg.DefaultResponseFallbackJsonPaths() + } + got := autoExtractResponseContent([]byte(tt.body), fallbackPaths) + if got != tt.want { + t.Errorf("autoExtractResponseContent() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAutoExtractStreamingResponseContent(t *testing.T) { + tests := []struct { + name string + chunk string + fallbackPaths []string + want string + }{ + { + name: "OpenAI streaming format", + chunk: `{"choices":[{"delta":{"content":"hello"}}]}`, + want: "hello", + }, + { + name: "Claude streaming format", + chunk: `{"type":"content_block_delta","delta":{"type":"text_delta","text":"hello claude"}}`, + want: "hello claude", + }, + { + name: "Claude thinking delta - no text extracted", + chunk: `{"type":"content_block_delta","delta":{"type":"thinking_delta","thinking":"let me think"}}`, + want: "", + }, + { + name: "empty chunk", + chunk: `{}`, + want: "", + }, + { + name: "OpenAI with data: prefix", + chunk: "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}", + want: "hello", + }, + { + name: "Claude with event: and data: prefix", + chunk: "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}", + want: "hello", + }, + { + name: "OpenAI with multi-line data fields", + chunk: `event: message +data: { +data: "choices": [{"delta": {"content": "hello multiline"}}] +data: }`, + want: "hello multiline", + }, + { + name: "data: [DONE] returns empty", + chunk: "data: [DONE]", + want: "", + }, + { + name: "custom streaming fallback path", + chunk: `{"payload":{"delta":"custom stream"}}`, + fallbackPaths: []string{"payload.delta"}, + want: "custom stream", + }, + { + name: "streaming fallback disabled explicitly", + chunk: `{"choices":[{"delta":{"content":"hello"}}]}`, + fallbackPaths: []string{}, + want: "", + }, + { + name: "empty chunk payload", + chunk: "", + want: "", + }, + { + name: "invalid json payload after data extraction", + chunk: "data: invalid-json", + want: "", + }, + { + name: "streaming payload with empty data line", + chunk: "event: message\ndata:\ndata: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}", + want: "hello", + }, + { + name: "streaming payload without data lines", + chunk: "event: ping", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fallbackPaths := tt.fallbackPaths + if fallbackPaths == nil { + fallbackPaths = cfg.DefaultStreamingResponseFallbackJsonPaths() + } + got := autoExtractStreamingResponseContent([]byte(tt.chunk), fallbackPaths) + if got != tt.want { + t.Errorf("autoExtractStreamingResponseContent() = %q, want %q", got, tt.want) + } + }) + } +} + +// Test that configured path takes priority over fallback. +func TestConfiguredPathPriority(t *testing.T) { + // Body has both OpenAI and a custom field + body := `{"choices":[{"message":{"content":"openai content"}}],"custom":"custom content"}` + + // Custom path extracts successfully - should NOT fall back + content := extractWithFallback([]byte(body), "custom", cfg.DefaultResponseFallbackJsonPaths()) + if content != "custom content" { + t.Errorf("expected custom path to take priority, got %q", content) + } + + // Custom path misses - should fall back to OpenAI + content = extractWithFallback([]byte(body), "nonexistent.path", cfg.DefaultResponseFallbackJsonPaths()) + if content != "openai content" { + t.Errorf("expected fallback to OpenAI, got %q", content) + } + + // Fallback disabled - should stay empty when configured path misses. + content = extractWithFallback([]byte(body), "nonexistent.path", []string{}) + if content != "" { + t.Errorf("expected empty result when fallback disabled, got %q", content) + } +} + +// extractWithFallback mirrors the real extraction logic in HandleTextGenerationResponseBody. +func extractWithFallback(body []byte, jsonPath string, fallbackPaths []string) string { + content := gjsonGetString(body, jsonPath) + if len(content) == 0 { + content = autoExtractResponseContent(body, fallbackPaths) + } + return content +} + +func gjsonGetString(body []byte, path string) string { + return gjson.GetBytes(body, path).String() +} + +// Test SSE body fallback for buffered streaming branch. +func TestAutoExtractStreamingResponseFromSSE(t *testing.T) { + tests := []struct { + name string + body string + fallbackPaths []string + want string + }{ + { + name: "OpenAI SSE body", + body: "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\ndata: [DONE]\n\n", + want: "hello world", + }, + { + name: "Claude SSE body with thinking and text deltas", + body: "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"hmm\"}}\n\n" + + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n" + + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" claude\"}}\n\n" + + "data: [DONE]\n\n", + want: "hello claude", + }, + { + name: "empty SSE body", + body: "data: [DONE]\n\n", + want: "", + }, + { + name: "OpenAI multi-line data events in full SSE body", + body: `event: message +data: { +data: "choices": [{"delta": {"content": "hello"}}] +data: } + +event: message +data: { +data: "choices": [{"delta": {"content": " world"}}] +data: } + +data: [DONE] + +`, + want: "hello world", + }, + { + name: "custom fallback paths in full SSE body", + body: "data: {\"payload\":{\"delta\":\"hello\"}}\n\ndata: {\"payload\":{\"delta\":\" world\"}}\n\n", + fallbackPaths: []string{ + "payload.delta", + }, + want: "hello world", + }, + { + name: "streaming fallback disabled for full SSE body", + body: "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n", + fallbackPaths: []string{}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fallbackPaths := tt.fallbackPaths + if fallbackPaths == nil { + fallbackPaths = cfg.DefaultStreamingResponseFallbackJsonPaths() + } + got := autoExtractStreamingResponseFromSSE([]byte(tt.body), fallbackPaths) + if got != tt.want { + t.Errorf("autoExtractStreamingResponseFromSSE() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestBuildEffectiveFallbackPaths(t *testing.T) { + if paths := buildEffectiveFallbackPaths("choices.0.message.content", nil); len(paths) != 0 { + t.Fatalf("expected empty paths when fallback list is nil, got %#v", paths) + } + + emptyByFilter := buildEffectiveFallbackPaths("choices.0.message.content", []string{ + "choices.0.message.content", + " ", + "", + }) + if len(emptyByFilter) != 0 { + t.Fatalf("expected empty paths after filtering duplicates/empty values, got %#v", emptyByFilter) + } + + paths := buildEffectiveFallbackPaths("choices.0.message.content", []string{ + "choices.0.message.content", + "delta.text", + "delta.text", + "", + " ", + "output.text", + }) + if len(paths) != 2 { + t.Fatalf("expected 2 paths after filtering, got %d", len(paths)) + } + if paths[0] != "delta.text" || paths[1] != "output.text" { + t.Fatalf("unexpected filtered fallback paths: %#v", paths) + } +} + +func TestGetEffectiveFallbackPathsFromContext(t *testing.T) { + ctx := &fallbackPathMockContext{values: make(map[string]interface{})} + got := getEffectiveFallbackPathsFromContext(ctx, "fallback_key", "choices.0.message.content", []string{ + "choices.0.message.content", + "output.text", + }) + if len(got) != 1 || got[0] != "output.text" { + t.Fatalf("unexpected effective paths from uncached context: %#v", got) + } + if cached, ok := ctx.values["fallback_key"].([]string); !ok || len(cached) != 1 || cached[0] != "output.text" { + t.Fatalf("expected effective paths to be cached in context, got %#v", ctx.values["fallback_key"]) + } + + ctx.values["fallback_key"] = []string{"cached.path"} + got = getEffectiveFallbackPathsFromContext(ctx, "fallback_key", "nonexistent", []string{"another.path"}) + if len(got) != 1 || got[0] != "cached.path" { + t.Fatalf("expected cached paths to take precedence, got %#v", got) + } +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 5969df7aa..841fdb3b4 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -333,6 +333,8 @@ func TestParseConfig(t *testing.T) { require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) require.Equal(t, uint32(2000), securityConfig.Timeout) require.Equal(t, 1000, securityConfig.BufferLimit) + require.Equal(t, cfg.DefaultResponseFallbackJsonPaths(), securityConfig.ResponseContentFallbackJsonPaths) + require.Equal(t, cfg.DefaultStreamingResponseFallbackJsonPaths(), securityConfig.ResponseStreamContentFallbackJsonPaths) }) // 测试仅检查请求的配置 @@ -390,6 +392,116 @@ func TestParseConfig(t *testing.T) { require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc")) require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test")) }) + + t.Run("custom response fallback paths config", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "responseContentFallbackJsonPaths": []string{"output.text", "choices.0.message.content"}, + "responseStreamContentFallbackJsonPaths": []string{"payload.delta", "delta.text"}, + }) + require.NoError(t, err) + host, status := test.NewTestHost(configJSON) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, []string{"output.text", "choices.0.message.content"}, securityConfig.ResponseContentFallbackJsonPaths) + require.Equal(t, []string{"payload.delta", "delta.text"}, securityConfig.ResponseStreamContentFallbackJsonPaths) + }) + + t.Run("empty response fallback paths disable fallback", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "responseContentFallbackJsonPaths": []string{}, + "responseStreamContentFallbackJsonPaths": []string{}, + }) + require.NoError(t, err) + host, status := test.NewTestHost(configJSON) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + securityConfig := config.(*cfg.AISecurityConfig) + require.Len(t, securityConfig.ResponseContentFallbackJsonPaths, 0) + require.Len(t, securityConfig.ResponseStreamContentFallbackJsonPaths, 0) + }) + + t.Run("invalid response fallback paths type", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "responseContentFallbackJsonPaths": "choices.0.message.content", + }) + require.NoError(t, err) + host, status := test.NewTestHost(configJSON) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + t.Run("invalid response fallback paths item", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "responseStreamContentFallbackJsonPaths": []interface{}{"delta.text", ""}, + }) + require.NoError(t, err) + host, status := test.NewTestHost(configJSON) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + t.Run("invalid response fallback paths non-string item", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "responseStreamContentFallbackJsonPaths": []interface{}{"delta.text", 123}, + }) + require.NoError(t, err) + host, status := test.NewTestHost(configJSON) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + t.Run("invalid contentModerationLevelBar value", func(t *testing.T) { + configJSON, err := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "contentModerationLevelBar": "invalid", + }) + require.NoError(t, err) + host, status := test.NewTestHost(configJSON) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) }) } @@ -632,6 +744,100 @@ func TestOnHttpResponseBody(t *testing.T) { }) } +func TestResponseFallbackExtractionCoverage(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + base := map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkResponse": true, + "action": "MultiModalGuard", + "apiType": "text_generation", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + "bufferLimit": 1000, + } + + withOverrides := func(overrides map[string]interface{}) json.RawMessage { + cfgMap := make(map[string]interface{}, len(base)+len(overrides)) + for k, v := range base { + cfgMap[k] = v + } + for k, v := range overrides { + cfgMap[k] = v + } + data, err := json.Marshal(cfgMap) + require.NoError(t, err) + return data + } + + t.Run("streaming response chunk uses configured fallback path", func(t *testing.T) { + host, status := test.NewTestHost(withOverrides(map[string]interface{}{ + "responseStreamContentJsonPath": "nonexistent.path", + "responseStreamContentFallbackJsonPaths": []string{"choices.0.delta.content"}, + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "text/event-stream"}, + }) + require.Equal(t, types.ActionContinue, action) + + chunk := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"hello fallback\"}}]}\n\n") + host.CallOnHttpStreamingResponseBody(chunk, true) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-stream-fallback", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + host.CompleteHttp() + }) + + t.Run("buffered response body uses streaming fallback extraction", func(t *testing.T) { + host, status := test.NewTestHost(withOverrides(map[string]interface{}{ + "responseStreamContentJsonPath": "nonexistent.path", + "responseStreamContentFallbackJsonPaths": []string{"choices.0.delta.content"}, + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "text/event-stream"}, + }) + + body := "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\ndata: [DONE]\n\n" + host.CallOnHttpResponseBody([]byte(body)) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-buffered-stream-fallback", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + host.CompleteHttp() + }) + }) +} + func TestMCP(t *testing.T) { test.RunTest(t, func(t *testing.T) { // Test MCP Response Body Check - Pass