diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 02660fad8..f4aee5632 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -305,7 +305,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } ctx.DontReadResponseBody() config.incrementCounter("ai_sec_request_deny", 1) - proxywasm.ResumeHttpRequest() + ctx.SetUserAttribute("safecheck_status", "request deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) } singleCall = func() { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") @@ -385,6 +390,11 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log return types.ActionContinue } hdsMap := convertHeaders(headers) + if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") { + log.Debugf("response is not 200, skip response body check") + ctx.DontReadResponseBody() + return types.ActionContinue + } ctx.SetContext("headers", hdsMap) return types.HeaderStopIteration } @@ -436,22 +446,24 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ denyMessage = response.Data.Advice[0].Answer } marshalledDenyMessage := marshalStr(denyMessage, log) - var jsonData []byte if config.protocolOriginal { - jsonData = []byte(marshalledDenyMessage) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) } else if isStreamingResponse { randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) } else { randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } - delete(hdsMap, "content-length") - hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)} - proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) - proxywasm.ReplaceHttpResponseBody(jsonData) config.incrementCounter("ai_sec_response_deny", 1) - proxywasm.ResumeHttpResponse() + ctx.SetUserAttribute("safecheck_status", "response deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) } singleCall = func() { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")