diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 86abec29e..385e0dd14 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -20,6 +20,7 @@ description: 阿里云内容安全检测 | `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 | | `accessKey` | string | requried | - | 阿里云AK | | `secretKey` | string | requried | - | 阿里云SK | +| `action` | string | requried | - | 阿里云ai安全业务接口 | | `checkRequest` | bool | optional | false | 检查提问内容是否合规 | | `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 | | `requestCheckService` | string | optional | llm_query_moderation | 指定阿里云内容安全用于检测输入内容的服务 | @@ -30,7 +31,9 @@ description: 阿里云内容安全检测 | `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | -| `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low | +| `contentModerationLevelBar` | string | optional | max | 内容合规检测拦截风险等级,取值为 `max`, `high`, `medium` or `low` | +| `promptAttackLevelBar` | string | optional | max | 提示词攻击检测拦截风险等级,取值为 `max`, `high`, `medium` or `low` | +| `sensitiveDataLevelBar` | string | optional | S4 | 敏感内容检测拦截风险等级,取值为 `S4`, `S3`, `S2` or `S1` | | `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 | | `bufferLimit` | int | optional | 1000 | 调用内容安全服务时每段文本的长度限制 | @@ -44,11 +47,19 @@ description: 阿里云内容安全检测 - 如果没有配置 `denyMessage`,优先返回阿里云内容安全的建议回答,非流式响应 - 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,非流式响应 -补充说明一下 `riskLevelBar` 的四个等级: -- `max`: 检测请求/响应内容,但是不会产生拦截行为 -- `high`: 内容安全检测结果中风险等级为 `high` 时产生拦截 -- `medium`: 内容安全检测结果中风险等级 >= `medium` 时产生拦截 -- `low`: 内容安全检测结果中风险等级 >= `low` 时产生拦截 +补充说明一下内容合规检测、提示词攻击检测、敏感内容检测三种风险的四个等级: + +- 对于内容合规检测、提示词攻击检测: + - `max`: 检测请求/响应内容,但是不会产生拦截行为 + - `high`: 内容安全检测/提示词攻击检测 结果中风险等级为 `high` 时产生拦截 + - `medium`: 内容安全检测/提示词攻击检测 结果中风险等级 >= `medium` 时产生拦截 + - `low`: 内容安全检测/提示词攻击检测 结果中风险等级 >= `low` 时产生拦截 + +- 对于敏感内容检测: + - `S4`: 检测请求/响应内容,但是不会产生拦截行为 + - `S3`: 敏感内容检测结果中风险等级为 `S3` 时产生拦截 + - `S2`: 敏感内容检测结果中风险等级 >= `S2` 时产生拦截 + - `S1`: 敏感内容检测结果中风险等级 >= `S1` 时产生拦截 ## 配置示例 ### 前提条件 @@ -143,21 +154,21 @@ curl http://localhost/v1/chat/completions \ ```json { - "id": "chatcmpl-AAy3hK1dE4ODaegbGOMoC9VY4Sizv", - "object": "chat.completion", - "created": 1677652288, - "model": "gpt-4o-mini", - "system_fingerprint": "fp_44709d6fcb", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。", - }, - "logprobs": null, - "finish_reason": "stop" - } - ] + "id": "chatcmpl-AAy3hK1dE4ODaegbGOMoC9VY4Sizv", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。", + }, + "logprobs": null, + "finish_reason": "stop" + } + ] } ``` diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md index 1dd96fa81..360893070 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -21,6 +21,7 @@ Plugin Priority: `300` | `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint | | `accessKey` | string | requried | - | Aliyun accesskey | | `secretKey` | string | requried | - | Aliyun secretkey | +| `action` | string | requried | - | Aliyun ai guardrails business interface | | `checkRequest` | bool | optional | false | check if the input is legal | | `checkResponse` | bool | optional | false | check if the output is legal | | `requestCheckService` | string | optional | llm_query_moderation | Aliyun yundun service name for input check | @@ -31,7 +32,9 @@ Plugin Priority: `300` | `denyCode` | int | optional | 200 | Response status code when the specified content is illegal | | `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal | | `protocol` | string | optional | openai | protocol format, `openai` or `original` | -| `riskLevelBar` | string | optional | high | risk level threshold, `max`, `high`, `medium` or `low` | +| `contentModerationLevelBar` | string | optional | max | contentModeration risk level threshold, `max`, `high`, `medium` or `low` | +| `promptAttackLevelBar` | string | optional | max | promptAttack risk level threshold, `max`, `high`, `medium` or `low` | +| `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` | | `timeout` | int | optional | 2000 | timeout for lvwang service | | `bufferLimit` | int | optional | 1000 | Limit the length of each text when calling the lvwang service | diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index a2005d8ed..9273c68f9 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -45,6 +45,16 @@ const ( LowRisk = "low" NoRisk = "none" + S4Sensitive = "S4" + S3Sensitive = "S3" + S2Sensitive = "S2" + S1Sensitive = "S1" + NoSensitive = "S0" + + ContentModerationType = "contentModeration" + PromptAttackType = "promptAttack" + SensitiveDataType = "sensitiveData" + OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` @@ -74,6 +84,7 @@ type Data struct { RiskLevel string `json:"RiskLevel"` Result []Result `json:"Result,omitempty"` Advice []Advice `json:"Advice,omitempty"` + Detail []Detail `json:"Detail,omitempty"` } type Result struct { @@ -89,11 +100,18 @@ type Advice struct { HitLibName string `json:"HitLibName,omitempty"` } +type Detail struct { + Suggestion string `json:"Suggestion,omitempty"` + Type string `json:"Type,omitempty"` + Level string `json:"Level,omitempty"` +} + type AISecurityConfig struct { client wrapper.HttpClient ak string sk string token string + action string checkRequest bool requestCheckService string requestContentJsonPath string @@ -104,7 +122,9 @@ type AISecurityConfig struct { denyCode int64 denyMessage string protocolOriginal bool - riskLevelBar string + contentModerationLevelBar string + promptAttackLevelBar string + sensitiveDataLevelBar string timeout uint32 bufferLimit int metrics map[string]proxywasm.MetricCounter @@ -121,23 +141,47 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) counter.Increment(inc) } -func riskLevelToInt(riskLevel string) int { +func levelToInt(riskLevel string) int { switch riskLevel { - case MaxRisk: + case MaxRisk, S4Sensitive: return 4 - case HighRisk: + case HighRisk, S3Sensitive: return 3 - case MediumRisk: + case MediumRisk, S2Sensitive: return 2 - case LowRisk: + case LowRisk, S1Sensitive: return 1 - case NoRisk: + case NoRisk, NoSensitive: return 0 default: return -1 } } +func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bool { + if action == "MultiModalGuard" { + for _, detail := range data.Detail { + switch detail.Type { + case ContentModerationType: + if levelToInt(detail.Level) >= levelToInt(config.contentModerationLevelBar) { + return false + } + case PromptAttackType: + if levelToInt(detail.Level) >= levelToInt(config.promptAttackLevelBar) { + return false + } + case SensitiveDataType: + if levelToInt(detail.Level) >= levelToInt(config.sensitiveDataLevelBar) { + return false + } + } + } + return true + } else { + return levelToInt(data.RiskLevel) < levelToInt(config.contentModerationLevelBar) + } +} + func urlEncoding(rawStr string) string { encodedStr := url.PathEscape(rawStr) encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B") @@ -192,6 +236,7 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error { return errors.New("invalid AK/SK config") } config.token = json.Get("securityToken").String() + config.action = json.Get("action").String() config.checkRequest = json.Get("checkRequest").Bool() config.checkResponse = json.Get("checkResponse").Bool() config.protocolOriginal = json.Get("protocol").String() == "original" @@ -226,13 +271,29 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error { } else { config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath } - if obj := json.Get("riskLevelBar"); obj.Exists() { - config.riskLevelBar = obj.String() - if riskLevelToInt(config.riskLevelBar) <= 0 { - return errors.New("invalid risk level, value must be one of [max, high, medium, low]") + if obj := json.Get("contentModerationLevelBar"); obj.Exists() { + config.contentModerationLevelBar = obj.String() + if levelToInt(config.contentModerationLevelBar) <= 0 { + return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]") } } else { - config.riskLevelBar = HighRisk + config.contentModerationLevelBar = MaxRisk + } + if obj := json.Get("promptAttackLevelBar"); obj.Exists() { + config.promptAttackLevelBar = obj.String() + if levelToInt(config.promptAttackLevelBar) <= 0 { + return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]") + } + } else { + config.promptAttackLevelBar = MaxRisk + } + if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() { + config.sensitiveDataLevelBar = obj.String() + if levelToInt(config.sensitiveDataLevelBar) <= 0 { + return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]") + } + } else { + config.sensitiveDataLevelBar = S4Sensitive } if obj := json.Get("timeout"); obj.Exists() { config.timeout = uint32(obj.Int()) @@ -306,7 +367,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] proxywasm.ResumeHttpRequest() return } - if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { + if isRiskLevelAcceptable(config.action, response.Data, config) { if contentIndex >= len(content) { endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) @@ -370,11 +431,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] "SignatureMethod": "Hmac-SHA1", "SignatureNonce": randomID, "SignatureVersion": "1.0", - "Action": "TextModerationPlus", + "Action": config.action, "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)), + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent), } if config.token != "" { params["SecurityToken"] = config.token @@ -444,7 +505,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi ctx.SetContext("during_call", false) return } - if riskLevelToInt(response.Data.RiskLevel) >= riskLevelToInt(config.riskLevelBar) { + if !isRiskLevelAcceptable(config.action, response.Data, config) { denyMessage := DefaultDenyMessage if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { denyMessage = "\n" + response.Data.Advice[0].Answer @@ -495,11 +556,11 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi "SignatureMethod": "Hmac-SHA1", "SignatureNonce": randomID, "SignatureVersion": "1.0", - "Action": "TextModerationPlus", + "Action": config.action, "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer)), + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer), AliyunUserAgent), } if config.token != "" { params["SecurityToken"] = config.token @@ -563,7 +624,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ proxywasm.ResumeHttpResponse() return } - if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { + if isRiskLevelAcceptable(config.action, response.Data, config) { if contentIndex >= len(content) { endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) @@ -626,11 +687,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ "SignatureMethod": "Hmac-SHA1", "SignatureNonce": randomID, "SignatureVersion": "1.0", - "Action": "TextModerationPlus", + "Action": config.action, "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)), + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent), } if config.token != "" { params["SecurityToken"] = config.token