feat(ai-security-guard): enhance risk action resolution and support sensitive data masking (#3690)

Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
This commit is contained in:
JianweiWang
2026-04-15 11:14:56 +08:00
committed by GitHub
parent e2beb6cd45
commit b1187cc14d
13 changed files with 5019 additions and 214 deletions

View File

@@ -64,9 +64,13 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
}
contentIndex := 0
imageIndex := 0
hasMasked := false
maskedContent := []byte(content)
sessionID, _ := utils.GenerateHexID(20)
var singleCall func()
var singleCallForImage func()
// prevContentIndex tracks the start of the current chunk for masking replacement
prevContentIndex := 0
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
@@ -80,11 +84,47 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
proxywasm.ResumeHttpRequest()
return
}
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if contentIndex >= len(content) {
riskResult := cfg.EvaluateRisk(config.Action, response.Data, config, consumer)
proxywasm.LogInfof("safecheck_resolved_action=%v", riskResult)
switch riskResult {
case cfg.RiskPass:
if contentIndex >= len(maskedContent) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
if hasMasked {
// All chunks processed, some had masking — replace the content in request body
newBody, replaceErr := utils.ReplaceJsonFieldTextContent(body, config.RequestContentJsonPath, string(maskedContent))
if replaceErr != nil {
log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
// Fall back to block to prevent leaking sensitive data
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return
}
proxywasm.ReplaceHttpRequestBody(newBody)
config.IncrementCounter("ai_sec_request_mask", 1)
ctx.SetUserAttribute("safecheck_status", "request mask")
} else {
ctx.SetUserAttribute("safecheck_status", "request pass")
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
@@ -95,45 +135,110 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
singleCall()
}
return
case cfg.RiskMask:
desensitization := cfg.ExtractDesensitization(response.Data)
if desensitization == "" {
proxywasm.LogInfof("safecheck_action_source=mask_fallback_to_block, reason=empty_desensitization")
log.Warnf("desensitization content is empty, falling back to block logic")
} else {
// Replace only the current chunk portion in maskedContent
chunkStart := prevContentIndex
chunkEnd := contentIndex
maskedContent = append(maskedContent[:chunkStart], append([]byte(desensitization), maskedContent[chunkEnd:]...)...)
// Adjust contentIndex for the length difference
lengthDiff := len(desensitization) - (chunkEnd - chunkStart)
contentIndex += lengthDiff
hasMasked = true
// Continue checking remaining chunks
if contentIndex >= len(maskedContent) {
// All chunks done, apply the masked content
newBody, replaceErr := utils.ReplaceJsonFieldTextContent(body, config.RequestContentJsonPath, string(maskedContent))
if replaceErr != nil {
log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
// Fall back to block to prevent leaking sensitive data
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return
}
proxywasm.ReplaceHttpRequestBody(newBody)
config.IncrementCounter("ai_sec_request_mask", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request mask")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
} else {
singleCall()
}
return
}
// Fall through to block logic when desensitization is empty
fallthrough
case cfg.RiskBlock:
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpRequest()
return
}
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpRequest()
return
}
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
prevContentIndex = contentIndex
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
if contentIndex+cfg.LengthLimit >= len(maskedContent) {
nextContentIndex = len(maskedContent)
} else {
nextContentIndex = contentIndex + cfg.LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentPiece := string(maskedContent[contentIndex:nextContentIndex])
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)