From 2cb8558cda1570ff71edba403d462e95451d52c7 Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Mon, 11 Nov 2024 14:49:17 +0800 Subject: [PATCH] Optimize AI security guard plugin (#1473) Co-authored-by: Kent Dong --- .../extensions/ai-security-guard/README.md | 26 +- .../extensions/ai-security-guard/go.mod | 2 +- .../extensions/ai-security-guard/go.sum | 4 +- .../extensions/ai-security-guard/main.go | 411 ++++++++++-------- 4 files changed, 242 insertions(+), 201 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 122d9e367..68eeeae20 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -30,19 +30,23 @@ description: 阿里云内容安全检测 | `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | +| `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low | -补充说明一下 `denyMessage`,对于openai格式的请求,对非法请求的处理逻辑为: -- 如果配置了 `denyMessage` - - 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应 - - 如果阿里云内容安全未返回建议的回答,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应 -- 如果没有配置 `denyMessage` - - 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应 - - 如果阿里云内容安全未返回建议的回答,返回内容为内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为openai格式的流式/非流式响应 +补充说明一下 `denyMessage`,对非法请求的处理逻辑为: +- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应 +- 如果没有配置 `denyMessage`,优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应 +- 如果阿里云内容安全未返回建议的回答,返回内容为内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为openai格式的流式/非流式响应 -如果用户使用了非openai格式的协议,应当配置 `denyMessage`,此时对非法请求的处理逻辑为: -- 返回用户配置的 `denyMessage` 内容,用户可以配置其为序列化后的json字符串,以保持与正常请求接口返回格式的一致性 -- 如果 `denyMessage` 为空,优先返回阿里云内容安全的建议回答,格式为纯文本 -- 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为纯文本 +如果用户使用了非openai格式的协议,此时对非法请求的处理逻辑为: +- 如果配置了 `denyMessage`,返回用户配置的 `denyMessage` 内容,非流式响应 +- 如果没有配置 `denyMessage`,优先返回阿里云内容安全的建议回答,非流式响应 +- 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,非流式响应 + +补充说明一下 `riskLevelBar` 的四个等级: +- `max`: 检测请求/响应内容,但是不会产生拦截行为 +- `high`: 内容安全检测结果中风险等级为 `high` 时产生拦截 +- `medium`: 内容安全检测结果中风险等级 >= `medium` 时产生拦截 +- `low`: 内容安全检测结果中风险等级 >= `low` 时产生拦截 ## 配置示例 ### 前提条件 diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.mod b/plugins/wasm-go/extensions/ai-security-guard/go.mod index f2bc5a143..3ab8ec183 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.mod +++ b/plugins/wasm-go/extensions/ai-security-guard/go.mod @@ -7,7 +7,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 ) require ( diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.sum b/plugins/wasm-go/extensions/ai-security-guard/go.sum index f473e12b2..042eae70f 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.sum +++ b/plugins/wasm-go/extensions/ai-security-guard/go.sum @@ -9,8 +9,8 @@ github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 5b6158961..cf7cdead0 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -35,12 +35,16 @@ func main() { } const ( - OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":%s,"choices":[{"index":0,"message":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":"stop"}]}` - OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":%s,"choices":[{"index":0,"delta":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":null}]}` - OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model": %s,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` - OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` + MaxRisk = "max" + HighRisk = "high" + MediumRisk = "medium" + LowRisk = "low" + NoRisk = "none" - TracingPrefix = "trace_span_tag." + OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}` + OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` + OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` + OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` DefaultRequestCheckService = "llm_query_moderation" DefaultResponseCheckService = "llm_response_moderation" @@ -53,10 +57,37 @@ const ( AliyunUserAgent = "CIPFrom/AIGateway" ) +type Response struct { + Code int `json:"Code"` + Message string `json:"Message"` + RequestId string `json:"RequestId"` + Data Data `json:"Data"` +} + +type Data struct { + RiskLevel string `json:"RiskLevel"` + Result []Result `json:"Result,omitempty"` + Advice []Advice `json:"Advice,omitempty"` +} + +type Result struct { + RiskWords string `json:"RiskWords,omitempty"` + Description string `json:"Description,omitempty"` + Confidence float64 `json:"Confidence,omitempty"` + Label string `json:"Label,omitempty"` +} + +type Advice struct { + Answer string `json:"Answer,omitempty"` + HitLabel string `json:"HitLabel,omitempty"` + HitLibName string `json:"HitLibName,omitempty"` +} + type AISecurityConfig struct { client wrapper.HttpClient ak string sk string + token string checkRequest bool requestCheckService string requestContentJsonPath string @@ -67,6 +98,7 @@ type AISecurityConfig struct { denyCode int64 denyMessage string protocolOriginal bool + riskLevelBar string metrics map[string]proxywasm.MetricCounter } @@ -79,12 +111,31 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) counter.Increment(inc) } +func riskLevelToInt(riskLevel string) int { + switch riskLevel { + case MaxRisk: + return 4 + case HighRisk: + return 3 + case MediumRisk: + return 2 + case LowRisk: + return 1 + case NoRisk: + return 0 + default: + return -1 + } +} + func urlEncoding(rawStr string) string { encodedStr := url.PathEscape(rawStr) encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B") encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A") encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D") encodedStr = strings.ReplaceAll(encodedStr, "&", "%26") + encodedStr = strings.ReplaceAll(encodedStr, "$", "%24") + encodedStr = strings.ReplaceAll(encodedStr, "@", "%40") return encodedStr } @@ -130,6 +181,7 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e if config.ak == "" || config.sk == "" { return errors.New("invalid AK/SK config") } + config.token = json.Get("securityToken").String() config.checkRequest = json.Get("checkRequest").Bool() config.checkResponse = json.Get("checkResponse").Bool() config.protocolOriginal = json.Get("protocol").String() == "original" @@ -164,6 +216,14 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e } else { config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath } + if obj := json.Get("riskLevelBar"); obj.Exists() { + config.riskLevelBar = obj.String() + if riskLevelToInt(config.riskLevelBar) <= 0 { + return errors.New("invalid risk level, value must be one of [max, high, medium, low]") + } + } else { + config.riskLevelBar = HighRisk + } config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: serviceName, Port: servicePort, @@ -192,105 +252,82 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking request body...") - content := gjson.GetBytes(body, config.requestContentJsonPath).Raw - model := gjson.GetBytes(body, "model").Raw + content := gjson.GetBytes(body, config.requestContentJsonPath).String() + model := gjson.GetBytes(body, "model").String() ctx.SetContext("requestModel", model) log.Debugf("Raw request content is: %s", content) - if len(content) > 0 { - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": "TextModerationPlus", - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": config.requestCheckService, - "ServiceParameters": fmt.Sprintf(`{"content": %s}`, content), - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode != 200 { - log.Error(string(responseBody)) - proxywasm.ResumeHttpRequest() - return - } - respData := gjson.GetBytes(responseBody, "Data") - if respData.Exists() { - respAdvice := respData.Get("Advice") - respResult := respData.Get("Result") - var denyMessage string - messageNeedSerialization := true - if config.protocolOriginal { - // not openai - if config.denyMessage != "" { - denyMessage = config.denyMessage - } else if respAdvice.Exists() { - denyMessage = respAdvice.Array()[0].Get("Answer").Raw - messageNeedSerialization = false - } else { - denyMessage = DefaultDenyMessage - } - } else { - // openai - if respAdvice.Exists() { - denyMessage = respAdvice.Array()[0].Get("Answer").Raw - messageNeedSerialization = false - } else if config.denyMessage != "" { - denyMessage = config.denyMessage - } else { - denyMessage = DefaultDenyMessage - } - } - if messageNeedSerialization { - if data, err := json.Marshal(denyMessage); err == nil { - denyMessage = string(data) - } else { - denyMessage = fmt.Sprintf("\"%s\"", DefaultDenyMessage) - } - } - if respResult.Array()[0].Get("Label").String() != "nonLabel" { - proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) - proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request")) - config.incrementCounter("ai_sec_request_deny", 1) - if config.protocolOriginal { - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyMessage), -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } - ctx.DontReadResponseBody() - } else { - proxywasm.ResumeHttpRequest() - } - } else { - proxywasm.ResumeHttpRequest() - } - }, - ) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - return types.ActionContinue - } - return types.ActionPause - } else { - log.Debugf("request content is empty. skip") + if len(content) == 0 { + log.Info("request content is empty. skip") return types.ActionContinue } + timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") + randomID, _ := generateHexID(16) + params := map[string]string{ + "Format": "JSON", + "Version": "2022-03-02", + "SignatureMethod": "Hmac-SHA1", + "SignatureNonce": randomID, + "SignatureVersion": "1.0", + "Action": "TextModerationPlus", + "AccessKeyId": config.ak, + "Timestamp": timestamp, + "Service": config.requestCheckService, + "ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)), + } + if config.token != "" { + params["SecurityToken"] = config.token + } + signature := getSign(params, config.sk+"&") + reqParams := url.Values{} + for k, v := range params { + reqParams.Add(k, v) + } + reqParams.Add("Signature", signature) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpRequest() + return + } + var response Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at request phase") + proxywasm.ResumeHttpRequest() + return + } + if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { + proxywasm.ResumeHttpRequest() + return + } + denyMessage := DefaultDenyMessage + if config.denyMessage != "" { + denyMessage = config.denyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := marshalStr(denyMessage, log) + if config.protocolOriginal { + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } + ctx.DontReadResponseBody() + config.incrementCounter("ai_sec_request_deny", 1) + }, + ) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + return types.ActionContinue + } + return types.ActionPause } func convertHeaders(hs [][2]string) map[string][]string { @@ -341,92 +378,81 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ if isStreamingResponse { content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) } else { - content = gjson.GetBytes(body, config.responseContentJsonPath).Raw + content = gjson.GetBytes(body, config.responseContentJsonPath).String() } log.Debugf("Raw response content is: %s", content) - if len(content) > 0 { - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": "TextModerationPlus", - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": config.responseCheckService, - "ServiceParameters": fmt.Sprintf(`{"content": %s}`, content), - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - defer proxywasm.ResumeHttpResponse() - if statusCode != 200 { - log.Error(string(responseBody)) - return - } - respData := gjson.GetBytes(responseBody, "Data") - if respData.Exists() { - respAdvice := respData.Get("Advice") - respResult := respData.Get("Result") - var denyMessage string - if config.protocolOriginal { - // not openai - if config.denyMessage != "" { - denyMessage = config.denyMessage - } else if respAdvice.Exists() { - denyMessage = respAdvice.Array()[0].Get("Answer").Raw - } else { - denyMessage = DefaultDenyMessage - } - } else { - // openai - if respAdvice.Exists() { - denyMessage = respAdvice.Array()[0].Get("Answer").Raw - } else if config.denyMessage != "" { - denyMessage = config.denyMessage - } else { - denyMessage = DefaultDenyMessage - } - } - if respResult.Array()[0].Get("Label").String() != "nonLabel" { - var jsonData []byte - if config.protocolOriginal { - jsonData = []byte(denyMessage) - } else if isStreamingResponse { - randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model)) - } else { - randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage)) - } - delete(hdsMap, "content-length") - hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)} - proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) - proxywasm.ReplaceHttpResponseBody(jsonData) - proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) - proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response")) - config.incrementCounter("ai_sec_response_deny", 1) - } - } - }, - ) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - return types.ActionContinue - } - return types.ActionPause - } else { - log.Debugf("request content is empty. skip") + if len(content) == 0 { + log.Info("response content is empty. skip") return types.ActionContinue } + timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") + randomID, _ := generateHexID(16) + params := map[string]string{ + "Format": "JSON", + "Version": "2022-03-02", + "SignatureMethod": "Hmac-SHA1", + "SignatureNonce": randomID, + "SignatureVersion": "1.0", + "Action": "TextModerationPlus", + "AccessKeyId": config.ak, + "Timestamp": timestamp, + "Service": config.responseCheckService, + "ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)), + } + if config.token != "" { + params["SecurityToken"] = config.token + } + signature := getSign(params, config.sk+"&") + reqParams := url.Values{} + for k, v := range params { + reqParams.Add(k, v) + } + reqParams.Add("Signature", signature) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + defer proxywasm.ResumeHttpResponse() + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + return + } + var response Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at response phase") + return + } + if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { + return + } + denyMessage := DefaultDenyMessage + if config.denyMessage != "" { + denyMessage = config.denyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := marshalStr(denyMessage, log) + var jsonData []byte + if config.protocolOriginal { + jsonData = []byte(marshalledDenyMessage) + } else if isStreamingResponse { + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + } else { + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + } + delete(hdsMap, "content-length") + hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)} + proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) + proxywasm.ReplaceHttpResponseBody(jsonData) + config.incrementCounter("ai_sec_response_deny", 1) + }, + ) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + return types.ActionContinue + } + return types.ActionPause } func extractMessageFromStreamingBody(data []byte, jsonPath string) string { @@ -434,10 +460,21 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string { strChunks := []string{} for _, chunk := range chunks { // Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}] - jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw - if len(jsonRaw) > 2 { - strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1]) - } + strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String()) + } + return strings.Join(strChunks, "") +} + +func marshalStr(raw string, log wrapper.Log) string { + helper := map[string]string{ + "placeholder": raw, + } + marshalledHelper, _ := json.Marshal(helper) + marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw + if len(marshalledRaw) >= 2 { + return marshalledRaw[1 : len(marshalledRaw)-1] + } else { + log.Errorf("failed to marshal json string, raw string is: %s", raw) + return "" } - return fmt.Sprintf(`"%s"`, strings.Join(strChunks, "")) }