Files
higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go
JianweiWang c21a38e783 feat(ai-security-guard): structured x_higress deny response, error-path metrics, and AI logging (#3894)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
2026-05-29 10:45:10 +08:00

453 lines
17 KiB
Go

package text
import (
	"bytes"
	"encoding/json"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/log"
	"github.com/higress-group/wasm-go/pkg/wrapper"
	"github.com/tidwall/gjson"
)
// Keys under which per-request response state is stored on the HTTP context.
const (
	// responseFallbackPathsCtxKey caches the effective fallback JSON paths
	// used for non-streaming response bodies.
	responseFallbackPathsCtxKey = "response_fallback_paths"
	// responseStreamFallbackPathsCtxKey caches the effective fallback JSON
	// paths used for streaming response chunks.
	responseStreamFallbackPathsCtxKey = "response_stream_fallback_paths"
	// responseStartTimeCtxKey stores the Unix-millisecond timestamp taken
	// when the response headers were handled.
	responseStartTimeCtxKey = "response_start_time"
)
// HandleTextGenerationResponseHeader prepares the per-request state used by
// the response-body handlers and decides how the body will be delivered.
//
// It resets the "end_of_stream_received", "during_call" and "risk_detected"
// flags, records the current time in Unix milliseconds, caches the effective
// fallback JSON paths for both the non-streaming and the streaming case, and
// stores a session ID obtained from utils.GenerateHexID(20).
//
// When the content-type header contains "text/event-stream" it calls
// ctx.NeedPauseStreamingResponse and returns types.ActionContinue; otherwise
// it calls ctx.BufferResponseBody and returns types.HeaderStopIteration.
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	// A missing header yields an empty media type, which takes the buffered path below.
	mediaType, _ := proxywasm.GetHttpResponseHeader("content-type")

	ctx.SetContext("end_of_stream_received", false)
	ctx.SetContext("during_call", false)
	ctx.SetContext("risk_detected", false)
	ctx.SetContext(responseStartTimeCtxKey, time.Now().UnixMilli())

	bodyPaths := buildEffectiveFallbackPaths(config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths)
	ctx.SetContext(responseFallbackPathsCtxKey, bodyPaths)
	streamPaths := buildEffectiveFallbackPaths(config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
	ctx.SetContext(responseStreamFallbackPathsCtxKey, streamPaths)

	// The generation error is deliberately ignored, as before; the ID is stored as returned.
	id, _ := utils.GenerateHexID(20)
	ctx.SetContext("sessionID", id)

	if !strings.Contains(mediaType, "text/event-stream") {
		ctx.BufferResponseBody()
		return types.HeaderStopIteration
	}
	ctx.NeedPauseStreamingResponse()
	return types.ActionContinue
}
// HandleTextGenerationStreamingResponseBody handles one chunk of a streaming
// (SSE) response. While no risk has been detected, it splits the chunk into
// events, pushes them onto the context buffer and, when no check is in flight,
// submits the text extracted from the buffered events to the response check
// service. The asynchronous callback then injects either the buffered events
// or a deny payload via proxywasm.InjectEncodedDataToFilterChain.
//
// The function itself always returns an empty byte slice for the current
// chunk.
//
// NOTE(review): the unchecked .(bool) assertions on "end_of_stream_received",
// "during_call" and "risk_detected" panic if those keys were never stored;
// HandleTextGenerationResponseHeader stores all three — confirm it always runs first.
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
	// consumer is the empty string when no string value is stored under "consumer".
	consumer, _ := ctx.GetContext("consumer").(string)
	streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
	// Reuse the session ID already stored on the context, or create and store one.
	var sessionID string
	if ctx.GetContext("sessionID") == nil {
		sessionID, _ = utils.GenerateHexID(20)
		ctx.SetContext("sessionID", sessionID)
	} else {
		sessionID, _ = ctx.GetContext("sessionID").(string)
	}
	// bufferQueue holds the events popped for the submission being checked so
	// they can be forwarded once the verdict arrives.
	var bufferQueue [][]byte
	currentSubmissionIndex := 0
	var singleCall func()
	// callback processes the check service's reply for the submission started by singleCall.
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Treat a non-200 HTTP status or a non-200 "Code" field in the body as a service error.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
			cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
			if ctx.GetContext("end_of_stream_received").(bool) {
				proxywasm.ResumeHttpResponse()
			}
			ctx.SetContext("during_call", false)
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Error("failed to unmarshal aliyun content security response at response phase")
			startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
			cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
			if ctx.GetContext("end_of_stream_received").(bool) {
				proxywasm.ResumeHttpResponse()
			}
			ctx.SetContext("during_call", false)
			return
		}
		// Risk level not acceptable: replace the buffered content with a deny payload.
		if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			jsonData, err := cfg.BuildOpenAIDenyData(config, response, consumer, true)
			if err != nil {
				// Build failure → fail-open: inject the buffered upstream content as-is.
				// Make this path observable so operators can spot the silent passthrough
				// instead of mistakenly attributing observed denies-only to the success
				// path's metrics. Symmetric with the success path's observability suite
				// (counter / safecheck_response_rt / safecheck_status / log / risk_detected).
				cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultError)
				log.Errorf("failed to build deny response body: %v", err)
				// The injected data ends the stream only when the upstream has finished
				// and no further events are waiting in the context buffer.
				endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
				proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
				bufferQueue = [][]byte{}
				config.IncrementCounter("ai_sec_response_deny_buildfail", 1)
				startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
				ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime)
				ctx.SetUserAttribute("safecheck_status", "build_fallback_pass")
				if len(response.Data.Result) > 0 {
					ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
					ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
				}
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				// More events may still be pending: clear the in-flight flag and try the next submission.
				if !endStream {
					ctx.SetContext("during_call", false)
					singleCall()
				}
				return
			}
			cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
			// The deny payload is injected with the end-of-stream flag set to true.
			proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
			// Once set, later invocations stop buffering incoming chunks (see the check near the end).
			ctx.SetContext("risk_detected", true)
			ctx.SetContext("during_call", false)
			config.IncrementCounter("ai_sec_response_deny", 1)
			startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
			ctx.SetUserAttribute("safecheck_response_rt", time.Now().UnixMilli()-startTime)
			ctx.SetUserAttribute("safecheck_status", "response deny")
			if len(response.Data.Result) > 0 {
				ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
				ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
			}
			ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
			return
		}
		// Pass: forward the events that were held back for this submission.
		cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
		cfg.WriteGuardrailLog(ctx)
		endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
		proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
		bufferQueue = [][]byte{}
		if !endStream {
			ctx.SetContext("during_call", false)
			singleCall()
		}
	}
	// singleCall starts at most one check at a time: it is a no-op while
	// "during_call" is true, and otherwise submits once enough events are
	// queued or the end of the stream has been received.
	singleCall = func() {
		if ctx.GetContext("during_call").(bool) {
			return
		}
		// NOTE(review): config.BufferLimit is compared with the queue size here and
		// with the rune count of the extracted text below — confirm the intended unit.
		if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
			var buffer string
			for ctx.BufferQueueSize() > 0 {
				front := ctx.PopBuffer()
				bufferQueue = append(bufferQueue, front)
				// Try the primary JSON path first, then the configured fallback paths.
				msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
				if len(msg) == 0 {
					msg = autoExtractStreamingResponseContent(front, streamFallbackPaths)
				}
				buffer += msg
				if len([]rune(buffer)) >= config.BufferLimit {
					break
				}
			}
			// case 1: streaming body has reasoning_content, part of buffer maybe empty
			// case 2: streaming body has toolcall result, part of buffer maybe empty
			log.Debugf("current content piece: %s", buffer)
			// A placeholder is submitted when no text could be extracted from the popped events.
			if len(buffer) == 0 {
				buffer = "[empty content]"
			}
			ctx.SetContext("during_call", true)
			log.Debugf("current content piece: %s", buffer)
			checkService := config.GetResponseCheckService(consumer)
			currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText)
			path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
			err := config.Client.Post(path, headers, body, callback, config.Timeout)
			if err != nil {
				log.Errorf("failed call the safe check service: %v", err)
				startTime, _ := ctx.GetContext(responseStartTimeCtxKey).(int64)
				cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
				// NOTE(review): "during_call" is not reset in this branch — confirm whether
				// MarkGuardrailResponseError handles that or later chunks stay queued.
				if ctx.GetContext("end_of_stream_received").(bool) {
					proxywasm.ResumeHttpResponse()
				}
			}
		}
	}
	if !ctx.GetContext("risk_detected").(bool) {
		// Split the unified chunk into events on blank lines ("\n\n") and push each
		// event separately, re-attaching the separators that Split removed.
		unifiedChunk := wrapper.UnifySSEChunk(data)
		hasTrailingSeparator := bytes.HasSuffix(unifiedChunk, []byte("\n\n"))
		trimmedChunk := bytes.TrimSpace(unifiedChunk)
		chunks := bytes.Split(trimmedChunk, []byte("\n\n"))
		// Filter out empty chunks
		nonEmptyChunks := make([][]byte, 0, len(chunks))
		for _, chunk := range chunks {
			if len(chunk) > 0 {
				nonEmptyChunks = append(nonEmptyChunks, chunk)
			}
		}
		// Restore separators
		for i := range len(nonEmptyChunks) - 1 {
			nonEmptyChunks[i] = append(nonEmptyChunks[i], []byte("\n\n")...)
		}
		// The last event gets its separator back only if the incoming data ended with one.
		if hasTrailingSeparator && len(nonEmptyChunks) > 0 {
			nonEmptyChunks[len(nonEmptyChunks)-1] = append(nonEmptyChunks[len(nonEmptyChunks)-1], []byte("\n\n")...)
		}
		for _, chunk := range nonEmptyChunks {
			ctx.PushBuffer(chunk)
		}
		// for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
		// ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
		// }
		ctx.SetContext("end_of_stream_received", endOfStream)
		if !ctx.GetContext("during_call").(bool) {
			singleCall()
		}
	} else if endOfStream {
		// A risk was already detected: incoming data is not buffered; only resume at end of stream.
		proxywasm.ResumeHttpResponse()
	}
	return []byte{}
}
// HandleTextGenerationResponseBody checks a fully buffered response body
// against the response check service.
//
// The text to check is extracted with the configured JSON path (streaming or
// non-streaming variant, chosen from the content-type header), falling back to
// the configured fallback paths when the primary path yields nothing. If no
// text is found the function returns types.ActionContinue without calling the
// service.
//
// Otherwise the text is submitted in pieces of at most cfg.LengthLimit bytes,
// one request at a time, and the function returns types.ActionPause. Each
// piece ends on a UTF-8 rune boundary so that a multi-byte character is never
// split across two submissions. The callback either submits the next piece,
// resumes the response (all pieces passed, or an error occurred), or sends a
// deny response via cfg.SendDenyResponse.
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	// consumer is the empty string when no string value is stored under "consumer".
	consumer, _ := ctx.GetContext("consumer").(string)
	responseFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseFallbackPathsCtxKey, config.ResponseContentJsonPath, config.ResponseContentFallbackJsonPaths)
	streamFallbackPaths := getEffectiveFallbackPathsFromContext(ctx, responseStreamFallbackPathsCtxKey, config.ResponseStreamContentJsonPath, config.ResponseStreamContentFallbackJsonPaths)
	log.Debugf("checking response body...")
	startTime := time.Now().UnixMilli()
	contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
	isStreamingResponse := strings.Contains(contentType, "event-stream")
	var content string
	if isStreamingResponse {
		content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
		if len(content) == 0 {
			content = autoExtractStreamingResponseFromSSE(body, streamFallbackPaths)
		}
	} else {
		content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
		if len(content) == 0 {
			content = autoExtractResponseContent(body, responseFallbackPaths)
		}
	}
	log.Debugf("Raw response content is: %s", content)
	if len(content) == 0 {
		log.Info("response content is empty. skip")
		return types.ActionContinue
	}
	// contentIndex is the byte offset of the first byte not yet submitted.
	contentIndex := 0
	sessionID, _ := utils.GenerateHexID(20)
	currentSubmissionIndex := 0
	var singleCall func()
	// callback processes the check service's reply for the piece submitted by singleCall.
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Treat a non-200 HTTP status or a non-200 "Code" field in the body as a service error.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
			proxywasm.ResumeHttpResponse()
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Error("failed to unmarshal aliyun content security response at response phase")
			cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
			proxywasm.ResumeHttpResponse()
			return
		}
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			// The pass attributes are set only after the final piece, and before the
			// submission event is completed.
			if contentIndex >= len(content) {
				endTime := time.Now().UnixMilli()
				ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "response pass")
			}
			cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultPass)
			if contentIndex >= len(content) {
				cfg.WriteGuardrailLog(ctx)
				proxywasm.ResumeHttpResponse()
			} else {
				// More content remains: check the next piece.
				singleCall()
			}
			return
		}
		// Risk level not acceptable: send the deny response; if it cannot be built, fail open.
		if err := cfg.SendDenyResponse(config, response, consumer, isStreamingResponse); err != nil {
			log.Errorf("failed to build deny response body: %v", err)
			cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, responseBody, startTime)
			proxywasm.ResumeHttpResponse()
			return
		}
		config.IncrementCounter("ai_sec_response_deny", 1)
		endTime := time.Now().UnixMilli()
		ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
		ctx.SetUserAttribute("safecheck_status", "response deny")
		if len(response.Data.Result) > 0 {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		cfg.CompleteGuardrailSubmissionEvent(ctx, currentSubmissionIndex, responseBody, cfg.GuardrailResultDeny)
		cfg.WriteGuardrailLog(ctx)
	}
	// singleCall submits the next piece of content, starting at contentIndex.
	singleCall = func() {
		nextContentIndex := contentIndex + cfg.LengthLimit
		if nextContentIndex >= len(content) {
			nextContentIndex = len(content)
		} else {
			// content is cut by byte offset, so the cut may land inside a multi-byte
			// UTF-8 sequence. Back up (by at most UTFMax-1 bytes) to the start of that
			// rune so the piece stays valid UTF-8. The byte-based cut is kept when
			// backing up would make no progress or the bytes are not valid UTF-8.
			aligned := nextContentIndex
			for steps := 0; steps < utf8.UTFMax-1 && aligned > contentIndex && !utf8.RuneStart(content[aligned]); steps++ {
				aligned--
			}
			if aligned > contentIndex && utf8.RuneStart(content[aligned]) {
				nextContentIndex = aligned
			}
		}
		contentPiece := content[contentIndex:nextContentIndex]
		contentIndex = nextContentIndex
		log.Debugf("current content piece: %s", contentPiece)
		checkService := config.GetResponseCheckService(consumer)
		currentSubmissionIndex = cfg.BeginGuardrailSubmissionEvent(ctx, cfg.GuardrailPhaseResponse, cfg.GuardrailModalityText)
		path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			// The request could not be dispatched: record the error and let the response through.
			log.Errorf("failed call the safe check service: %v", err)
			cfg.MarkGuardrailResponseError(ctx, currentSubmissionIndex, nil, startTime)
			proxywasm.ResumeHttpResponse()
		}
	}
	singleCall()
	return types.ActionPause
}
// autoExtractResponseContent tries configured fallback paths to extract text content.
func autoExtractResponseContent(body []byte, fallbackPaths []string) string {
if len(fallbackPaths) == 0 {
return ""
}
parsed := gjson.ParseBytes(body)
return extractTextByPaths(parsed, fallbackPaths)
}
// autoExtractStreamingResponseContent extracts text from a single streaming
// chunk using the configured fallback paths.
//
// The chunk may be bare JSON (starting with '{' or '[') or an SSE event; for
// an SSE event the JSON is taken from its "data:" lines, including multi-line
// data events. It returns "" when there are no fallback paths, when the chunk
// is blank, when no data payload can be found, or when the payload is not
// valid JSON.
func autoExtractStreamingResponseContent(chunk []byte, fallbackPaths []string) string {
	if len(fallbackPaths) == 0 {
		return ""
	}
	candidate := bytes.TrimSpace(chunk)
	if len(candidate) == 0 {
		return ""
	}
	// Anything that does not look like bare JSON is treated as an SSE event.
	if !isJSONPayload(candidate) {
		candidate = extractSSEDataPayload(candidate)
	}
	if len(candidate) == 0 || !json.Valid(candidate) {
		return ""
	}
	return extractTextByPaths(gjson.ParseBytes(candidate), fallbackPaths)
}
// isJSONPayload reports whether payload starts with '{' or '[', i.e. looks
// like a JSON object or array. It inspects only the first byte and does not
// validate the payload.
func isJSONPayload(payload []byte) bool {
	if len(payload) == 0 {
		return false
	}
	switch payload[0] {
	case '{', '[':
		return true
	default:
		return false
	}
}
// extractSSEDataPayload collects the values of all "data:" lines in one SSE
// event and joins them with '\n', as SSE specifies for multi-line data fields.
//
// Each line and each value is whitespace-trimmed; lines without the "data:"
// prefix and data lines with an empty value are ignored. It returns nil when
// the event has no usable data line or when any data value is "[DONE]".
func extractSSEDataPayload(chunk []byte) []byte {
	var (
		dataField = []byte("data:")
		newline   = []byte("\n")
		values    [][]byte
	)
	remaining := chunk
	for more := true; more; {
		var line []byte
		line, remaining, more = bytes.Cut(remaining, newline)

		value, isData := bytes.CutPrefix(bytes.TrimSpace(line), dataField)
		if !isData {
			continue
		}
		value = bytes.TrimSpace(value)
		switch {
		case len(value) == 0:
			continue
		case string(value) == "[DONE]":
			// The stream terminator discards the whole event, including earlier data lines.
			return nil
		}
		values = append(values, value)
	}
	if len(values) == 0 {
		return nil
	}
	return bytes.TrimSpace(bytes.Join(values, newline))
}
// buildEffectiveFallbackPaths normalises the configured fallback paths.
//
// Every path is whitespace-trimmed; empty paths, paths equal to the trimmed
// primaryPath, and repeated paths are dropped. The order of first occurrence
// is kept. The result is always a non-nil slice, empty when nothing remains.
func buildEffectiveFallbackPaths(primaryPath string, fallbackPaths []string) []string {
	effective := []string{}
	if len(fallbackPaths) == 0 {
		return effective
	}
	primary := strings.TrimSpace(primaryPath)
	visited := make(map[string]bool, len(fallbackPaths))
	for _, raw := range fallbackPaths {
		candidate := strings.TrimSpace(raw)
		switch {
		case candidate == "", candidate == primary, visited[candidate]:
			// Blank, same as the primary path, or already kept: skip.
		default:
			visited[candidate] = true
			effective = append(effective, candidate)
		}
	}
	return effective
}
// fallbackPathContext is the minimal key/value store needed to cache the
// effective fallback paths: a getter and a setter keyed by string.
type fallbackPathContext interface {
	// GetContext returns the value stored under key.
	GetContext(key string) any
	// SetContext stores value under key.
	SetContext(key string, value any)
}
// getEffectiveFallbackPathsFromContext returns the effective fallback paths
// cached on ctx under ctxKey. When no []string is stored there, it builds the
// list from primaryPath and fallbackPaths, stores it under ctxKey and returns it.
func getEffectiveFallbackPathsFromContext(ctx fallbackPathContext, ctxKey string, primaryPath string, fallbackPaths []string) []string {
	paths, hit := ctx.GetContext(ctxKey).([]string)
	if !hit {
		paths = buildEffectiveFallbackPaths(primaryPath, fallbackPaths)
		ctx.SetContext(ctxKey, paths)
	}
	return paths
}
// extractTextByPaths tries each path in order against parsed and returns the
// text of the first one that exists and yields non-empty text. Paths are
// whitespace-trimmed and blank paths are skipped. It returns "" when no path
// produces text.
func extractTextByPaths(parsed gjson.Result, paths []string) string {
	for _, raw := range paths {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		match := parsed.Get(candidate)
		if !match.Exists() {
			continue
		}
		text := extractTextFromResult(match)
		if text == "" {
			// The path exists but holds no text: keep trying the remaining paths.
			continue
		}
		log.Debugf("response fallback path matched: %s", candidate)
		return text
	}
	return ""
}
// extractTextFromResult returns the string form of result. For an array it
// concatenates the string form of every element, in order and with no
// separator.
func extractTextFromResult(result gjson.Result) string {
	if !result.IsArray() {
		return result.String()
	}
	var sb strings.Builder
	for _, elem := range result.Array() {
		sb.WriteString(elem.String())
	}
	return sb.String()
}
// autoExtractStreamingResponseFromSSE applies the configured fallback paths to
// a full SSE body. The body is passed through wrapper.UnifySSEChunk, trimmed
// and split on blank lines ("\n\n"); the text extracted from each piece is
// concatenated in order. It returns "" when no fallback paths are given.
func autoExtractStreamingResponseFromSSE(data []byte, fallbackPaths []string) string {
	if len(fallbackPaths) == 0 {
		return ""
	}
	var sb strings.Builder
	events := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
	for i := range events {
		sb.WriteString(autoExtractStreamingResponseContent(events[i], fallbackPaths))
	}
	return sb.String()
}