diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 8b63213d5..122d9e367 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -28,7 +28,21 @@ description: 阿里云内容安全检测 | `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath | | `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath | | `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | -| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 | +| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | +| `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | + +补充说明一下 `denyMessage`,对于openai格式的请求,对非法请求的处理逻辑为: +- 如果配置了 `denyMessage` + - 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应 + - 如果阿里云内容安全未返回建议的回答,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应 +- 如果没有配置 `denyMessage` + - 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应 + - 如果阿里云内容安全未返回建议的回答,返回内容为内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为openai格式的流式/非流式响应 + +如果用户使用了非openai格式的协议,应当配置 `denyMessage`,此时对非法请求的处理逻辑为: +- 返回用户配置的 `denyMessage` 内容,用户可以配置其为序列化后的json字符串,以保持与正常请求接口返回格式的一致性 +- 如果 `denyMessage` 为空,优先返回阿里云内容安全的建议回答,格式为纯文本 +- 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为纯文本 ## 配置示例 ### 前提条件 @@ -90,6 +104,7 @@ requestContentJsonPath: "input.prompt" responseContentJsonPath: "output.text" denyCode: 200 denyMessage: "很抱歉,我无法回答您的问题" +protocol: original ``` ## 可观测 diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md index 450b55417..0367686af 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -29,8 +29,8 @@ Plugin Priority: `300` | `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body | | `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body | | `denyCode` | int | optional | 200 | Response status code when the specified content is illegal | -| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security - | Response content when the specified content is illegal | +| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal | +| `protocol` | string | optional | openai | protocol format, `openai` or `original` | ## Examples of configuration diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 4640596c8..de87119d3 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -47,6 +47,7 @@ const ( DefaultResponseJsonPath = "choices.0.message.content" DefaultStreamingResponseJsonPath = "choices.0.delta.content" DefaultDenyCode = 200 + DefaultDenyMessage = "很抱歉,我无法回答您的问题" AliyunUserAgent = "CIPFrom/AIGateway" ) @@ -64,6 +65,7 @@ type AISecurityConfig struct { responseStreamContentJsonPath string denyCode int64 denyMessage string + protocolOriginal bool metrics map[string]proxywasm.MetricCounter } @@ -129,6 +131,7 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e } config.checkRequest = json.Get("checkRequest").Bool() config.checkResponse = json.Get("checkResponse").Bool() + config.protocolOriginal = json.Get("protocol").String() == "original" config.denyMessage = json.Get("denyMessage").String() if obj := json.Get("denyCode"); obj.Exists() { config.denyCode = obj.Int() @@ -218,29 +221,51 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] reqParams.Add("Signature", signature) err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 { + log.Error(string(responseBody)) + proxywasm.ResumeHttpRequest() + return + } respData := gjson.GetBytes(responseBody, "Data") if respData.Exists() { respAdvice := respData.Get("Advice") respResult := respData.Get("Result") - if respAdvice.Exists() { + var denyMessage string + if config.protocolOriginal { + // not openai + if config.denyMessage != "" { + denyMessage = config.denyMessage + } else if respAdvice.Exists() { + denyMessage = respAdvice.Array()[0].Get("Answer").Raw + } else { + denyMessage = DefaultDenyMessage + } + } else { + // openai + if respAdvice.Exists() { + denyMessage = respAdvice.Array()[0].Get("Answer").Raw + } else if config.denyMessage != "" { + denyMessage = config.denyMessage + } else { + denyMessage = DefaultDenyMessage + } + } + if respResult.Array()[0].Get("Label").String() != "nonLabel" { proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request")) config.incrementCounter("ai_sec_request_deny", 1) - if config.denyMessage != "" { - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1) + if config.protocolOriginal { + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) } else { - answer := respAdvice.Array()[0].Get("Answer").Raw - if gjson.GetBytes(body, "stream").Bool() { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, answer)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, answer)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } - ctx.DontReadResponseBody() + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } + ctx.DontReadResponseBody() } else { proxywasm.ResumeHttpRequest() } @@ -330,22 +355,44 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { defer proxywasm.ResumeHttpResponse() + if statusCode != 200 { + log.Error(string(responseBody)) + return + } respData := gjson.GetBytes(responseBody, "Data") if respData.Exists() { respAdvice := respData.Get("Advice") respResult := respData.Get("Result") - if respAdvice.Exists() { - var jsonData []byte + var denyMessage string + if config.protocolOriginal { + // not openai if config.denyMessage != "" { - jsonData = []byte(config.denyMessage) + denyMessage = config.denyMessage + } else if respAdvice.Exists() { + denyMessage = respAdvice.Array()[0].Get("Answer").Raw } else { - if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { - randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String())) - } else { - randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String())) - } + denyMessage = DefaultDenyMessage + } + } else { + // openai + if respAdvice.Exists() { + denyMessage = respAdvice.Array()[0].Get("Answer").Raw + } else if config.denyMessage != "" { + denyMessage = config.denyMessage + } else { + denyMessage = DefaultDenyMessage + } + } + if respResult.Array()[0].Get("Label").String() != "nonLabel" { + var jsonData []byte + if config.protocolOriginal { + jsonData = []byte(denyMessage) + } else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String(), randomID, model)) + } else { + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String())) } delete(hdsMap, "content-length") hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}