mirror of
https://github.com/alibaba/higress.git
synced 2026-02-25 05:01:19 +08:00
Optimize AI security guard plugin (#1473)
Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
@@ -30,19 +30,23 @@ description: 阿里云内容安全检测
|
||||
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
|
||||
| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 |
|
||||
| `protocol` | string | optional | openai | 协议格式,非openai协议填`original` |
|
||||
| `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low |
|
||||
|
||||
补充说明一下 `denyMessage`,对于openai格式的请求,对非法请求的处理逻辑为:
|
||||
- 如果配置了 `denyMessage`
|
||||
- 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应
|
||||
- 如果阿里云内容安全未返回建议的回答,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应
|
||||
- 如果没有配置 `denyMessage`
|
||||
- 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应
|
||||
- 如果阿里云内容安全未返回建议的回答,返回内容为内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为openai格式的流式/非流式响应
|
||||
补充说明一下 `denyMessage`,对非法请求的处理逻辑为:
|
||||
- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应
|
||||
- 如果没有配置 `denyMessage`,优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应
|
||||
- 如果阿里云内容安全未返回建议的回答,返回内容为内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为openai格式的流式/非流式响应
|
||||
|
||||
如果用户使用了非openai格式的协议,应当配置 `denyMessage`,此时对非法请求的处理逻辑为:
|
||||
- 返回用户配置的 `denyMessage` 内容,用户可以配置其为序列化后的json字符串,以保持与正常请求接口返回格式的一致性
|
||||
- 如果 `denyMessage` 为空,优先返回阿里云内容安全的建议回答,格式为纯文本
|
||||
- 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为纯文本
|
||||
如果用户使用了非openai格式的协议,此时对非法请求的处理逻辑为:
|
||||
- 如果配置了 `denyMessage`,返回用户配置的 `denyMessage` 内容,非流式响应
|
||||
- 如果没有配置 `denyMessage`,优先返回阿里云内容安全的建议回答,非流式响应
|
||||
- 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,非流式响应
|
||||
|
||||
补充说明一下 `riskLevelBar` 的四个等级:
|
||||
- `max`: 检测请求/响应内容,但是不会产生拦截行为
|
||||
- `high`: 内容安全检测结果中风险等级为 `high` 时产生拦截
|
||||
- `medium`: 内容安全检测结果中风险等级 >= `medium` 时产生拦截
|
||||
- `low`: 内容安全检测结果中风险等级 >= `low` 时产生拦截
|
||||
|
||||
## 配置示例
|
||||
### 前提条件
|
||||
|
||||
@@ -7,7 +7,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
github.com/tidwall/gjson v1.17.3
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -9,8 +9,8 @@ github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
|
||||
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
|
||||
@@ -35,12 +35,16 @@ func main() {
|
||||
}
|
||||
|
||||
const (
|
||||
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":%s,"choices":[{"index":0,"message":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":"stop"}]}`
|
||||
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":%s,"choices":[{"index":0,"delta":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":null}]}`
|
||||
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model": %s,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
|
||||
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
||||
MaxRisk = "max"
|
||||
HighRisk = "high"
|
||||
MediumRisk = "medium"
|
||||
LowRisk = "low"
|
||||
NoRisk = "none"
|
||||
|
||||
TracingPrefix = "trace_span_tag."
|
||||
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}`
|
||||
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
|
||||
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
||||
|
||||
DefaultRequestCheckService = "llm_query_moderation"
|
||||
DefaultResponseCheckService = "llm_response_moderation"
|
||||
@@ -53,10 +57,37 @@ const (
|
||||
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"Code"`
|
||||
Message string `json:"Message"`
|
||||
RequestId string `json:"RequestId"`
|
||||
Data Data `json:"Data"`
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
RiskLevel string `json:"RiskLevel"`
|
||||
Result []Result `json:"Result,omitempty"`
|
||||
Advice []Advice `json:"Advice,omitempty"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
RiskWords string `json:"RiskWords,omitempty"`
|
||||
Description string `json:"Description,omitempty"`
|
||||
Confidence float64 `json:"Confidence,omitempty"`
|
||||
Label string `json:"Label,omitempty"`
|
||||
}
|
||||
|
||||
type Advice struct {
|
||||
Answer string `json:"Answer,omitempty"`
|
||||
HitLabel string `json:"HitLabel,omitempty"`
|
||||
HitLibName string `json:"HitLibName,omitempty"`
|
||||
}
|
||||
|
||||
type AISecurityConfig struct {
|
||||
client wrapper.HttpClient
|
||||
ak string
|
||||
sk string
|
||||
token string
|
||||
checkRequest bool
|
||||
requestCheckService string
|
||||
requestContentJsonPath string
|
||||
@@ -67,6 +98,7 @@ type AISecurityConfig struct {
|
||||
denyCode int64
|
||||
denyMessage string
|
||||
protocolOriginal bool
|
||||
riskLevelBar string
|
||||
metrics map[string]proxywasm.MetricCounter
|
||||
}
|
||||
|
||||
@@ -79,12 +111,31 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func riskLevelToInt(riskLevel string) int {
|
||||
switch riskLevel {
|
||||
case MaxRisk:
|
||||
return 4
|
||||
case HighRisk:
|
||||
return 3
|
||||
case MediumRisk:
|
||||
return 2
|
||||
case LowRisk:
|
||||
return 1
|
||||
case NoRisk:
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func urlEncoding(rawStr string) string {
|
||||
encodedStr := url.PathEscape(rawStr)
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "$", "%24")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "@", "%40")
|
||||
return encodedStr
|
||||
}
|
||||
|
||||
@@ -130,6 +181,7 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
|
||||
if config.ak == "" || config.sk == "" {
|
||||
return errors.New("invalid AK/SK config")
|
||||
}
|
||||
config.token = json.Get("securityToken").String()
|
||||
config.checkRequest = json.Get("checkRequest").Bool()
|
||||
config.checkResponse = json.Get("checkResponse").Bool()
|
||||
config.protocolOriginal = json.Get("protocol").String() == "original"
|
||||
@@ -164,6 +216,14 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
|
||||
} else {
|
||||
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
||||
}
|
||||
if obj := json.Get("riskLevelBar"); obj.Exists() {
|
||||
config.riskLevelBar = obj.String()
|
||||
if riskLevelToInt(config.riskLevelBar) <= 0 {
|
||||
return errors.New("invalid risk level, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
} else {
|
||||
config.riskLevelBar = HighRisk
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: servicePort,
|
||||
@@ -192,105 +252,82 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
log.Debugf("checking request body...")
|
||||
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
|
||||
model := gjson.GetBytes(body, "model").Raw
|
||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx.SetContext("requestModel", model)
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) > 0 {
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
"SignatureMethod": "Hmac-SHA1",
|
||||
"SignatureNonce": randomID,
|
||||
"SignatureVersion": "1.0",
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": config.requestCheckService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
reqParams := url.Values{}
|
||||
for k, v := range params {
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != 200 {
|
||||
log.Error(string(responseBody))
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
respData := gjson.GetBytes(responseBody, "Data")
|
||||
if respData.Exists() {
|
||||
respAdvice := respData.Get("Advice")
|
||||
respResult := respData.Get("Result")
|
||||
var denyMessage string
|
||||
messageNeedSerialization := true
|
||||
if config.protocolOriginal {
|
||||
// not openai
|
||||
if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
} else if respAdvice.Exists() {
|
||||
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
|
||||
messageNeedSerialization = false
|
||||
} else {
|
||||
denyMessage = DefaultDenyMessage
|
||||
}
|
||||
} else {
|
||||
// openai
|
||||
if respAdvice.Exists() {
|
||||
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
|
||||
messageNeedSerialization = false
|
||||
} else if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
} else {
|
||||
denyMessage = DefaultDenyMessage
|
||||
}
|
||||
}
|
||||
if messageNeedSerialization {
|
||||
if data, err := json.Marshal(denyMessage); err == nil {
|
||||
denyMessage = string(data)
|
||||
} else {
|
||||
denyMessage = fmt.Sprintf("\"%s\"", DefaultDenyMessage)
|
||||
}
|
||||
}
|
||||
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
|
||||
config.incrementCounter("ai_sec_request_deny", 1)
|
||||
if config.protocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := generateRandomID()
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := generateRandomID()
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
} else {
|
||||
log.Debugf("request content is empty. skip")
|
||||
if len(content) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
"SignatureMethod": "Hmac-SHA1",
|
||||
"SignatureNonce": randomID,
|
||||
"SignatureVersion": "1.0",
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": config.requestCheckService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)),
|
||||
}
|
||||
if config.token != "" {
|
||||
params["SecurityToken"] = config.token
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
reqParams := url.Values{}
|
||||
for k, v := range params {
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at request phase")
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
denyMessage := DefaultDenyMessage
|
||||
if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := marshalStr(denyMessage, log)
|
||||
if config.protocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := generateRandomID()
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := generateRandomID()
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.incrementCounter("ai_sec_request_deny", 1)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func convertHeaders(hs [][2]string) map[string][]string {
|
||||
@@ -341,92 +378,81 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
if isStreamingResponse {
|
||||
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
||||
} else {
|
||||
content = gjson.GetBytes(body, config.responseContentJsonPath).Raw
|
||||
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
|
||||
}
|
||||
log.Debugf("Raw response content is: %s", content)
|
||||
if len(content) > 0 {
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
"SignatureMethod": "Hmac-SHA1",
|
||||
"SignatureNonce": randomID,
|
||||
"SignatureVersion": "1.0",
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": config.responseCheckService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
reqParams := url.Values{}
|
||||
for k, v := range params {
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
defer proxywasm.ResumeHttpResponse()
|
||||
if statusCode != 200 {
|
||||
log.Error(string(responseBody))
|
||||
return
|
||||
}
|
||||
respData := gjson.GetBytes(responseBody, "Data")
|
||||
if respData.Exists() {
|
||||
respAdvice := respData.Get("Advice")
|
||||
respResult := respData.Get("Result")
|
||||
var denyMessage string
|
||||
if config.protocolOriginal {
|
||||
// not openai
|
||||
if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
} else if respAdvice.Exists() {
|
||||
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
|
||||
} else {
|
||||
denyMessage = DefaultDenyMessage
|
||||
}
|
||||
} else {
|
||||
// openai
|
||||
if respAdvice.Exists() {
|
||||
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
|
||||
} else if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
} else {
|
||||
denyMessage = DefaultDenyMessage
|
||||
}
|
||||
}
|
||||
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||
var jsonData []byte
|
||||
if config.protocolOriginal {
|
||||
jsonData = []byte(denyMessage)
|
||||
} else if isStreamingResponse {
|
||||
randomID := generateRandomID()
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
|
||||
} else {
|
||||
randomID := generateRandomID()
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
|
||||
}
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
|
||||
config.incrementCounter("ai_sec_response_deny", 1)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
} else {
|
||||
log.Debugf("request content is empty. skip")
|
||||
if len(content) == 0 {
|
||||
log.Info("response content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
"SignatureMethod": "Hmac-SHA1",
|
||||
"SignatureNonce": randomID,
|
||||
"SignatureVersion": "1.0",
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": config.responseCheckService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)),
|
||||
}
|
||||
if config.token != "" {
|
||||
params["SecurityToken"] = config.token
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
reqParams := url.Values{}
|
||||
for k, v := range params {
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
defer proxywasm.ResumeHttpResponse()
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
return
|
||||
}
|
||||
var response Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||
return
|
||||
}
|
||||
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
|
||||
return
|
||||
}
|
||||
denyMessage := DefaultDenyMessage
|
||||
if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := marshalStr(denyMessage, log)
|
||||
var jsonData []byte
|
||||
if config.protocolOriginal {
|
||||
jsonData = []byte(marshalledDenyMessage)
|
||||
} else if isStreamingResponse {
|
||||
randomID := generateRandomID()
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
|
||||
} else {
|
||||
randomID := generateRandomID()
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
|
||||
}
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
config.incrementCounter("ai_sec_response_deny", 1)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
@@ -434,10 +460,21 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
strChunks := []string{}
|
||||
for _, chunk := range chunks {
|
||||
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||
jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw
|
||||
if len(jsonRaw) > 2 {
|
||||
strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1])
|
||||
}
|
||||
strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String())
|
||||
}
|
||||
return strings.Join(strChunks, "")
|
||||
}
|
||||
|
||||
func marshalStr(raw string, log wrapper.Log) string {
|
||||
helper := map[string]string{
|
||||
"placeholder": raw,
|
||||
}
|
||||
marshalledHelper, _ := json.Marshal(helper)
|
||||
marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw
|
||||
if len(marshalledRaw) >= 2 {
|
||||
return marshalledRaw[1 : len(marshalledRaw)-1]
|
||||
} else {
|
||||
log.Errorf("failed to marshal json string, raw string is: %s", raw)
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`"%s"`, strings.Join(strChunks, ""))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user