diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index de87119d3..3051fe22e 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -7,6 +7,7 @@ import ( "crypto/sha1" "encoding/base64" "encoding/hex" + "encoding/json" "errors" "fmt" mrand "math/rand" @@ -194,7 +195,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log } func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { - proxywasm.LogDebugf("checking request body...") + log.Debugf("checking request body...") content := gjson.GetBytes(body, config.requestContentJsonPath).String() model := gjson.GetBytes(body, "model").Raw ctx.SetContext("requestModel", model) @@ -231,12 +232,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] respAdvice := respData.Get("Advice") respResult := respData.Get("Result") var denyMessage string + messageNeedSerialization := true if config.protocolOriginal { // not openai if config.denyMessage != "" { denyMessage = config.denyMessage } else if respAdvice.Exists() { denyMessage = respAdvice.Array()[0].Get("Answer").Raw + messageNeedSerialization = false } else { denyMessage = DefaultDenyMessage } @@ -244,12 +247,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] // openai if respAdvice.Exists() { denyMessage = respAdvice.Array()[0].Get("Answer").Raw + messageNeedSerialization = false } else if config.denyMessage != "" { denyMessage = config.denyMessage } else { denyMessage = DefaultDenyMessage } } + if messageNeedSerialization { + if data, err := json.Marshal(denyMessage); err == nil { + denyMessage = string(data) + } else { + denyMessage = fmt.Sprintf("\"%s\"", DefaultDenyMessage) + } + } if respResult.Array()[0].Get("Label").String() != "nonLabel" { proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request")) @@ -280,7 +291,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } return types.ActionPause } else { - proxywasm.LogDebugf("request content is empty. skip") + log.Debugf("request content is empty. skip") return types.ActionContinue } } @@ -320,7 +331,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log } func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { - proxywasm.LogDebugf("checking response body...") + log.Debugf("checking response body...") hdsMap := ctx.GetContext("headers").(map[string][]string) isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") model := ctx.GetStringContext("requestModel", "unknown") @@ -411,7 +422,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ } return types.ActionPause } else { - proxywasm.LogDebugf("request content is empty. skip") + log.Debugf("request content is empty. skip") return types.ActionContinue } }