fix: Fix the quotation issue of deny message in ai-security-guard (#1352)

This commit is contained in:
Kent Dong
2024-09-27 18:45:51 +08:00
committed by GitHub
parent 1b119ed371
commit 71aae9ddf6

View File

@@ -7,6 +7,7 @@ import (
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
mrand "math/rand" mrand "math/rand"
@@ -194,7 +195,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
} }
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
proxywasm.LogDebugf("checking request body...") log.Debugf("checking request body...")
content := gjson.GetBytes(body, config.requestContentJsonPath).String() content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").Raw model := gjson.GetBytes(body, "model").Raw
ctx.SetContext("requestModel", model) ctx.SetContext("requestModel", model)
@@ -231,12 +232,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
respAdvice := respData.Get("Advice") respAdvice := respData.Get("Advice")
respResult := respData.Get("Result") respResult := respData.Get("Result")
var denyMessage string var denyMessage string
messageNeedSerialization := true
if config.protocolOriginal { if config.protocolOriginal {
// not openai // not openai
if config.denyMessage != "" { if config.denyMessage != "" {
denyMessage = config.denyMessage denyMessage = config.denyMessage
} else if respAdvice.Exists() { } else if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw denyMessage = respAdvice.Array()[0].Get("Answer").Raw
messageNeedSerialization = false
} else { } else {
denyMessage = DefaultDenyMessage denyMessage = DefaultDenyMessage
} }
@@ -244,12 +247,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
// openai // openai
if respAdvice.Exists() { if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw denyMessage = respAdvice.Array()[0].Get("Answer").Raw
messageNeedSerialization = false
} else if config.denyMessage != "" { } else if config.denyMessage != "" {
denyMessage = config.denyMessage denyMessage = config.denyMessage
} else { } else {
denyMessage = DefaultDenyMessage denyMessage = DefaultDenyMessage
} }
} }
if messageNeedSerialization {
if data, err := json.Marshal(denyMessage); err == nil {
denyMessage = string(data)
} else {
denyMessage = fmt.Sprintf("\"%s\"", DefaultDenyMessage)
}
}
if respResult.Array()[0].Get("Label").String() != "nonLabel" { if respResult.Array()[0].Get("Label").String() != "nonLabel" {
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request")) proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
@@ -280,7 +291,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
} }
return types.ActionPause return types.ActionPause
} else { } else {
proxywasm.LogDebugf("request content is empty. skip") log.Debugf("request content is empty. skip")
return types.ActionContinue return types.ActionContinue
} }
} }
@@ -320,7 +331,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
} }
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
proxywasm.LogDebugf("checking response body...") log.Debugf("checking response body...")
hdsMap := ctx.GetContext("headers").(map[string][]string) hdsMap := ctx.GetContext("headers").(map[string][]string)
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
model := ctx.GetStringContext("requestModel", "unknown") model := ctx.GetStringContext("requestModel", "unknown")
@@ -411,7 +422,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
} }
return types.ActionPause return types.ActionPause
} else { } else {
proxywasm.LogDebugf("request content is empty. skip") log.Debugf("request content is empty. skip")
return types.ActionContinue return types.ActionContinue
} }
} }