mirror of
https://github.com/alibaba/higress.git
synced 2026-05-28 14:47:29 +08:00
feat(ai-security-guard): add fallback JSON paths for response content extraction (#3738)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
This commit is contained in:
@@ -18,11 +18,18 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
responseFallbackPathsCtxKey = "response_fallback_paths"
|
||||
responseStreamFallbackPathsCtxKey = "response_stream_fallback_paths"
|
||||
)
|
||||
|
||||
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
ctx.SetContext("end_of_stream_received", false)
|
||||
ctx.SetContext("during_call", false)
|
||||
ctx.SetContext("risk_detected", false)
|
||||
ctx.SetContext(responseFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths))
|
||||
ctx.SetContext(responseStreamFallbackPathsCtxKey, buildEffectiveFallbackPaths(config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths))
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
ctx.SetContext("sessionID", sessionID)
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
@@ -36,6 +43,7 @@ func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISe
|
||||
|
||||
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
|
||||
var sessionID string
|
||||
if ctx.GetContext("sessionID") == nil {
|
||||
sessionID, _ = utils.GenerateHexID(20)
|
||||
@@ -101,6 +109,9 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
|
||||
front := ctx.PopBuffer()
|
||||
bufferQueue = append(bufferQueue, front)
|
||||
msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
|
||||
if len(msg) == 0 {
|
||||
msg = autoExtractStreamingResponseContent(front, streamFallbackPaths)
|
||||
}
|
||||
buffer += msg
|
||||
if len([]rune(buffer)) >= config.BufferLimit {
|
||||
break
|
||||
@@ -162,6 +173,8 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
|
||||
|
||||
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
responseFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseFallbackPathsCtxKey, config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths)
|
||||
streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
|
||||
log.Debugf("checking response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
@@ -169,8 +182,14 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
|
||||
var content string
|
||||
if isStreamingResponse {
|
||||
content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
|
||||
if len(content) == 0 {
|
||||
content = autoExtractStreamingResponseFromSSE(body, streamFallbackPaths)
|
||||
}
|
||||
} else {
|
||||
content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
|
||||
if len(content) == 0 {
|
||||
content = autoExtractResponseContent(body, responseFallbackPaths)
|
||||
}
|
||||
}
|
||||
log.Debugf("Raw response content is: %s", content)
|
||||
if len(content) == 0 {
|
||||
@@ -255,3 +274,148 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
// autoExtractResponseContent tries configured fallback paths to extract text content.
|
||||
func autoExtractResponseContent(body []byte, fallbackPaths []string) string {
|
||||
if len(fallbackPaths) == 0 {
|
||||
return ""
|
||||
}
|
||||
parsed := gjson.ParseBytes(body)
|
||||
return extractTextByPaths(parsed, fallbackPaths)
|
||||
}
|
||||
|
||||
// autoExtractStreamingResponseContent tries configured fallback paths to extract text content.
|
||||
// It handles both bare JSON and SSE "data:" payloads, including multi-line data events.
|
||||
func autoExtractStreamingResponseContent(chunk []byte, fallbackPaths []string) string {
|
||||
if len(fallbackPaths) == 0 {
|
||||
return ""
|
||||
}
|
||||
payload := bytes.TrimSpace(chunk)
|
||||
if len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
if !isJSONPayload(payload) {
|
||||
payload = extractSSEDataPayload(payload)
|
||||
if len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if !json.Valid(payload) {
|
||||
return ""
|
||||
}
|
||||
parsed := gjson.ParseBytes(payload)
|
||||
return extractTextByPaths(parsed, fallbackPaths)
|
||||
}
|
||||
|
||||
func isJSONPayload(payload []byte) bool {
|
||||
return len(payload) > 0 && (payload[0] == '{' || payload[0] == '[')
|
||||
}
|
||||
|
||||
// extractSSEDataPayload concatenates all "data:" lines in one SSE event.
|
||||
// SSE specifies multi-line data fields should be joined with '\n'.
|
||||
func extractSSEDataPayload(chunk []byte) []byte {
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
dataLines := make([][]byte, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
return nil
|
||||
}
|
||||
dataLines = append(dataLines, data)
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.TrimSpace(bytes.Join(dataLines, []byte("\n")))
|
||||
}
|
||||
|
||||
func buildEffectiveFallbackPaths(primaryPath string, fallbackPaths []string) []string {
|
||||
primaryPath = strings.TrimSpace(primaryPath)
|
||||
if len(fallbackPaths) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
deduped := make([]string, 0, len(fallbackPaths))
|
||||
seen := make(map[string]struct{}, len(fallbackPaths))
|
||||
for _, path := range fallbackPaths {
|
||||
path = strings.TrimSpace(path)
|
||||
if len(path) == 0 || path == primaryPath {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[path]; ok {
|
||||
continue
|
||||
}
|
||||
seen[path] = struct{}{}
|
||||
deduped = append(deduped, path)
|
||||
}
|
||||
if len(deduped) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
return deduped
|
||||
}
|
||||
|
||||
type fallbackPathContext interface {
|
||||
GetContext(key string) interface{}
|
||||
SetContext(key string, value interface{})
|
||||
}
|
||||
|
||||
func getEffectiveFallbackPathsFromContext(ctx fallbackPathContext, ctxKey string, primaryPath string, fallbackPaths []string) []string {
|
||||
if cached, ok := ctx.GetContext(ctxKey).([]string); ok {
|
||||
return cached
|
||||
}
|
||||
effective := buildEffectiveFallbackPaths(primaryPath, fallbackPaths)
|
||||
ctx.SetContext(ctxKey, effective)
|
||||
return effective
|
||||
}
|
||||
|
||||
func extractTextByPaths(parsed gjson.Result, paths []string) string {
|
||||
for _, path := range paths {
|
||||
path = strings.TrimSpace(path)
|
||||
if len(path) == 0 {
|
||||
continue
|
||||
}
|
||||
result := parsed.Get(path)
|
||||
if !result.Exists() {
|
||||
continue
|
||||
}
|
||||
if text := extractTextFromResult(result); len(text) > 0 {
|
||||
log.Debugf("response fallback path matched: %s", path)
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractTextFromResult(result gjson.Result) string {
|
||||
if result.IsArray() {
|
||||
var parts []string
|
||||
for _, item := range result.Array() {
|
||||
if s := item.String(); len(s) > 0 {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// autoExtractStreamingResponseFromSSE tries configured fallback paths on a full SSE body.
|
||||
func autoExtractStreamingResponseFromSSE(data []byte, fallbackPaths []string) string {
|
||||
if len(fallbackPaths) == 0 {
|
||||
return ""
|
||||
}
|
||||
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
|
||||
var parts []string
|
||||
for _, chunk := range chunks {
|
||||
if s := autoExtractStreamingResponseContent(chunk, fallbackPaths); len(s) > 0 {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user