mirror of
https://github.com/alibaba/higress.git
synced 2026-05-25 13:17:28 +08:00
feat(ai-security-guard): enhance risk action resolution and support sensitive data masking (#3690)
Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
This commit is contained in:
@@ -64,9 +64,13 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
hasMasked := false
|
||||
maskedContent := []byte(content)
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
// prevContentIndex tracks the start of the current chunk for masking replacement
|
||||
prevContentIndex := 0
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
@@ -80,11 +84,47 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
riskResult := cfg.EvaluateRisk(config.Action, response.Data, config, consumer)
|
||||
proxywasm.LogInfof("safecheck_resolved_action=%v", riskResult)
|
||||
switch riskResult {
|
||||
case cfg.RiskPass:
|
||||
if contentIndex >= len(maskedContent) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
if hasMasked {
|
||||
// All chunks processed, some had masking — replace the content in request body
|
||||
newBody, replaceErr := utils.ReplaceJsonFieldTextContent(body, config.RequestContentJsonPath, string(maskedContent))
|
||||
if replaceErr != nil {
|
||||
log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
|
||||
// Fall back to block to prevent leaking sensitive data
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
return
|
||||
}
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
config.IncrementCounter("ai_sec_request_mask", 1)
|
||||
ctx.SetUserAttribute("safecheck_status", "request mask")
|
||||
} else {
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
@@ -95,45 +135,110 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
case cfg.RiskMask:
|
||||
desensitization := cfg.ExtractDesensitization(response.Data)
|
||||
if desensitization == "" {
|
||||
proxywasm.LogInfof("safecheck_action_source=mask_fallback_to_block, reason=empty_desensitization")
|
||||
log.Warnf("desensitization content is empty, falling back to block logic")
|
||||
} else {
|
||||
// Replace only the current chunk portion in maskedContent
|
||||
chunkStart := prevContentIndex
|
||||
chunkEnd := contentIndex
|
||||
maskedContent = append(maskedContent[:chunkStart], append([]byte(desensitization), maskedContent[chunkEnd:]...)...)
|
||||
// Adjust contentIndex for the length difference
|
||||
lengthDiff := len(desensitization) - (chunkEnd - chunkStart)
|
||||
contentIndex += lengthDiff
|
||||
hasMasked = true
|
||||
// Continue checking remaining chunks
|
||||
if contentIndex >= len(maskedContent) {
|
||||
// All chunks done, apply the masked content
|
||||
newBody, replaceErr := utils.ReplaceJsonFieldTextContent(body, config.RequestContentJsonPath, string(maskedContent))
|
||||
if replaceErr != nil {
|
||||
log.Errorf("failed to replace request body content, falling back to block: %v", replaceErr)
|
||||
// Fall back to block to prevent leaking sensitive data
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
return
|
||||
}
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
config.IncrementCounter("ai_sec_request_mask", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request mask")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Fall through to block logic when desensitization is empty
|
||||
fallthrough
|
||||
case cfg.RiskBlock:
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
prevContentIndex = contentIndex
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
if contentIndex+cfg.LengthLimit >= len(maskedContent) {
|
||||
nextContentIndex = len(maskedContent)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentPiece := string(maskedContent[contentIndex:nextContentIndex])
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
|
||||
Reference in New Issue
Block a user