mirror of
https://github.com/alibaba/higress.git
synced 2026-02-28 14:40:50 +08:00
Update ai security guard (#1261)
This commit is contained in:
@@ -1,22 +1,143 @@
|
||||
---
|
||||
title: AI内容安全
|
||||
keywords: [higress, AI, security]
|
||||
description: 阿里云内容安全检测
|
||||
---
|
||||
|
||||
## 功能说明
|
||||
通过对接阿里云内容安全检测大模型的输入输出,保障AI应用内容合法合规。
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:`默认阶段`
|
||||
插件执行优先级:`300`
|
||||
|
||||
## 配置说明
|
||||
| Name | Type | Requirement | Default | Description |
|
||||
| :-: | :-: | :-: | :-: | :-: |
|
||||
| serviceSource | string | requried | - | 服务来源,填dns |
|
||||
| serviceName | string | requried | - | 服务名 |
|
||||
| servicePort | string | requried | - | 服务端口 |
|
||||
| domain | string | requried | - | 阿里云内容安全endpoint |
|
||||
| ak | string | requried | - | 阿里云AK |
|
||||
| sk | string | requried | - | 阿里云SK |
|
||||
| ------------ | ------------ | ------------ | ------------ | ------------ |
|
||||
| `serviceName` | string | requried | - | 服务名 |
|
||||
| `servicePort` | string | requried | - | 服务端口 |
|
||||
| `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 |
|
||||
| `accessKey` | string | requried | - | 阿里云AK |
|
||||
| `secretKey` | string | requried | - | 阿里云SK |
|
||||
| `checkRequest` | bool | optional | false | 检查提问内容是否合规 |
|
||||
| `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 |
|
||||
| `requestCheckService` | string | optional | llm_query_moderation | 指定阿里云内容安全用于检测输入内容的服务 |
|
||||
| `responseCheckService` | string | optional | llm_response_moderation | 指定阿里云内容安全用于检测输出内容的服务 |
|
||||
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath |
|
||||
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath |
|
||||
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath |
|
||||
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
|
||||
| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 |
|
||||
|
||||
|
||||
## 配置示例
|
||||
### 前提条件
|
||||
由于插件中需要调用阿里云内容安全服务,所以需要先创建一个DNS类型的服务,例如:
|
||||
|
||||

|
||||
|
||||
### 检测输入内容是否合规
|
||||
|
||||
```yaml
|
||||
serviceSource: "dns"
|
||||
serviceName: "safecheck"
|
||||
serviceName: safecheck.dns
|
||||
servicePort: 443
|
||||
domain: "green-cip.cn-shanghai.aliyuncs.com"
|
||||
ak: "XXXXXXXXX"
|
||||
sk: "XXXXXXXXXXXXXXX"
|
||||
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||
accessKey: "XXXXXXXXX"
|
||||
secretKey: "XXXXXXXXXXXXXXX"
|
||||
checkRequest: true
|
||||
```
|
||||
|
||||
### 检测输入与输出是否合规
|
||||
|
||||
```yaml
|
||||
serviceName: safecheck.dns
|
||||
servicePort: 443
|
||||
serviceHost: green-cip.cn-shanghai.aliyuncs.com
|
||||
accessKey: "XXXXXXXXX"
|
||||
secretKey: "XXXXXXXXXXXXXXX"
|
||||
checkRequest: true
|
||||
checkResponse: true
|
||||
```
|
||||
|
||||
### 指定自定义内容安全检测服务
|
||||
用户可能需要根据不同的场景配置不同的检测规则,该问题可通过为不同域名/路由/服务配置不同的内容安全检测服务实现。如下图所示,我们创建了一个名为 llm_query_moderation_01 的检测服务,其中的检测规则在 llm_query_moderation 之上做了一些改动:
|
||||
|
||||

|
||||
|
||||
接下来在目标域名/路由/服务级别进行以下配置,指定使用我们自定义的 llm_query_moderation_01 中的规则进行检测:
|
||||
|
||||
```yaml
|
||||
serviceName: safecheck.dns
|
||||
servicePort: 443
|
||||
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||
accessKey: "XXXXXXXXX"
|
||||
secretKey: "XXXXXXXXXXXXXXX"
|
||||
checkRequest: true
|
||||
requestCheckService: llm_query_moderation_01
|
||||
```
|
||||
|
||||
### 配置非openai协议(例如百炼App)
|
||||
|
||||
```yaml
|
||||
serviceName: safecheck.dns
|
||||
servicePort: 443
|
||||
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||
accessKey: "XXXXXXXXX"
|
||||
secretKey: "XXXXXXXXXXXXXXX"
|
||||
checkRequest: true
|
||||
checkResponse: true
|
||||
requestContentJsonPath: "input.prompt"
|
||||
responseContentJsonPath: "output.text"
|
||||
denyCode: 200
|
||||
denyMessage: "很抱歉,我无法回答您的问题"
|
||||
```
|
||||
|
||||
## 可观测
|
||||
### Metric
|
||||
ai-security-guard 插件提供了以下监控指标:
|
||||
- `ai_sec_request_deny`: 请求内容安全检测失败请求数
|
||||
- `ai_sec_response_deny`: 模型回答安全检测失败请求数
|
||||
|
||||
### Trace
|
||||
如果开启了链路追踪,ai-security-guard 插件会在请求 span 中添加以下 attributes:
|
||||
- `ai_sec_risklabel`: 表示请求命中的风险类型
|
||||
- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段(取值为request或者response)
|
||||
|
||||
## 请求示例
|
||||
```bash
|
||||
curl http://localhost/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "这是一段非法内容"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "gpt-4o-mini",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
69
plugins/wasm-go/extensions/ai-security-guard/README_EN.md
Normal file
69
plugins/wasm-go/extensions/ai-security-guard/README_EN.md
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
title: AI Content Security
|
||||
keywords: [higress, AI, security]
|
||||
description: Alibaba Cloud content security
|
||||
---
|
||||
|
||||
|
||||
## Introduction
|
||||
Integrate with Aliyun content security service for detections of input and output of LLMs, ensuring that application content is legal and compliant.
|
||||
|
||||
## Runtime Properties
|
||||
|
||||
Plugin Phase: `CUSTOM`
|
||||
Plugin Priority: `300`
|
||||
|
||||
## Configuration
|
||||
| Name | Type | Requirement | Default | Description |
|
||||
| ------------ | ------------ | ------------ | ------------ | ------------ |
|
||||
| `serviceName` | string | requried | - | service name |
|
||||
| `servicePort` | string | requried | - | service port |
|
||||
| `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint |
|
||||
| `accessKey` | string | requried | - | Aliyun accesskey |
|
||||
| `secretKey` | string | requried | - | Aliyun secretkey |
|
||||
| `checkRequest` | bool | optional | false | check if the input is legal |
|
||||
| `checkResponse` | bool | optional | false | check if the output is legal |
|
||||
| `requestCheckService` | string | optional | llm_query_moderation | Aliyun yundun service name for input check |
|
||||
| `responseCheckService` | string | optional | llm_response_moderation | Aliyun yundun service name for output check |
|
||||
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body |
|
||||
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body |
|
||||
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body |
|
||||
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
|
||||
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security
|
||||
| Response content when the specified content is illegal |
|
||||
|
||||
|
||||
## Examples of configuration
|
||||
### Check if the input is legal
|
||||
|
||||
```yaml
|
||||
serviceName: safecheck.dns
|
||||
servicePort: 443
|
||||
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||
accessKey: "XXXXXXXXX"
|
||||
secretKey: "XXXXXXXXXXXXXXX"
|
||||
checkRequest: true
|
||||
```
|
||||
|
||||
### Check if both the input and output are legal
|
||||
|
||||
```yaml
|
||||
serviceName: safecheck.dns
|
||||
servicePort: 443
|
||||
serviceHost: green-cip.cn-shanghai.aliyuncs.com
|
||||
accessKey: "XXXXXXXXX"
|
||||
secretKey: "XXXXXXXXXXXXXXX"
|
||||
checkRequest: true
|
||||
checkResponse: true
|
||||
```
|
||||
|
||||
## Observability
|
||||
### Metric
|
||||
ai-security-guard plugin provides following metrics:
|
||||
- `ai_sec_request_deny`: count of requests denied at request phase
|
||||
- `ai_sec_response_deny`: count of requests denied at response phase
|
||||
|
||||
### Trace
|
||||
ai-security-guard plugin provides following span attributes:
|
||||
- `ai_sec_risklabel`: risk type of this request
|
||||
- `ai_sec_deny_phase`: denied phase of this request, value can be request/response
|
||||
@@ -1,4 +1,4 @@
|
||||
module myplugin
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard
|
||||
|
||||
go 1.18
|
||||
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -32,16 +32,47 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
const (
|
||||
OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}`
|
||||
OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||
OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
|
||||
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd
|
||||
|
||||
TracingPrefix = "trace_span_tag."
|
||||
|
||||
DefaultRequestCheckService = "llm_query_moderation"
|
||||
DefaultResponseCheckService = "llm_response_moderation"
|
||||
DefaultRequestJsonPath = "messages.@reverse.0.content"
|
||||
DefaultResponseJsonPath = "choices.0.message.content"
|
||||
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
|
||||
DefaultDenyCode = 200
|
||||
|
||||
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||
)
|
||||
|
||||
type AISecurityConfig struct {
|
||||
client wrapper.HttpClient
|
||||
ak string
|
||||
sk string
|
||||
client wrapper.HttpClient
|
||||
ak string
|
||||
sk string
|
||||
checkRequest bool
|
||||
requestCheckService string
|
||||
requestContentJsonPath string
|
||||
checkResponse bool
|
||||
responseCheckService string
|
||||
responseContentJsonPath string
|
||||
responseStreamContentJsonPath string
|
||||
denyCode int64
|
||||
denyMessage string
|
||||
metrics map[string]proxywasm.MetricCounter
|
||||
}
|
||||
|
||||
type StandardResponse struct {
|
||||
Code int `json:"Code"`
|
||||
Phase string `json:"BlockPhase"`
|
||||
Message string `json:"Message"`
|
||||
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
|
||||
counter, ok := config.metrics[metricName]
|
||||
if !ok {
|
||||
counter = proxywasm.DefineCounterMetric(metricName)
|
||||
config.metrics[metricName] = counter
|
||||
}
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func urlEncoding(rawStr string) string {
|
||||
@@ -71,7 +102,7 @@ func getSign(params map[string]string, secret string) string {
|
||||
})
|
||||
canonicalStr := strings.Join(paramArray, "&")
|
||||
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
||||
fmt.Println(signStr)
|
||||
// proxywasm.LogInfo(signStr)
|
||||
return hmacSha1(signStr, secret)
|
||||
}
|
||||
|
||||
@@ -86,32 +117,70 @@ func generateHexID(length int) (string, error) {
|
||||
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
|
||||
serviceName := json.Get("serviceName").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
domain := json.Get("domain").String()
|
||||
config.ak = json.Get("ak").String()
|
||||
config.sk = json.Get("sk").String()
|
||||
if serviceName == "" || servicePort == 0 || domain == "" {
|
||||
serviceHost := json.Get("serviceHost").String()
|
||||
if serviceName == "" || servicePort == 0 || serviceHost == "" {
|
||||
return errors.New("invalid service config")
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: serviceName,
|
||||
Port: servicePort,
|
||||
Domain: domain,
|
||||
config.ak = json.Get("accessKey").String()
|
||||
config.sk = json.Get("secretKey").String()
|
||||
if config.ak == "" || config.sk == "" {
|
||||
return errors.New("invalid AK/SK config")
|
||||
}
|
||||
config.checkRequest = json.Get("checkRequest").Bool()
|
||||
config.checkResponse = json.Get("checkResponse").Bool()
|
||||
config.denyMessage = json.Get("denyMessage").String()
|
||||
if obj := json.Get("denyCode"); obj.Exists() {
|
||||
config.denyCode = obj.Int()
|
||||
} else {
|
||||
config.denyCode = DefaultDenyCode
|
||||
}
|
||||
if obj := json.Get("requestCheckService"); obj.Exists() {
|
||||
config.requestCheckService = obj.String()
|
||||
} else {
|
||||
config.requestCheckService = DefaultRequestCheckService
|
||||
}
|
||||
if obj := json.Get("responseCheckService"); obj.Exists() {
|
||||
config.responseCheckService = obj.String()
|
||||
} else {
|
||||
config.responseCheckService = DefaultResponseCheckService
|
||||
}
|
||||
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
|
||||
config.requestContentJsonPath = obj.String()
|
||||
} else {
|
||||
config.requestContentJsonPath = DefaultRequestJsonPath
|
||||
}
|
||||
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
|
||||
config.responseContentJsonPath = obj.String()
|
||||
} else {
|
||||
config.responseContentJsonPath = DefaultResponseJsonPath
|
||||
}
|
||||
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
|
||||
config.responseStreamContentJsonPath = obj.String()
|
||||
} else {
|
||||
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: servicePort,
|
||||
Host: serviceHost,
|
||||
})
|
||||
config.metrics = make(map[string]proxywasm.MetricCounter)
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||
if !config.checkRequest {
|
||||
ctx.DontReadRequestBody()
|
||||
}
|
||||
if !config.checkResponse {
|
||||
ctx.DontReadResponseBody()
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
if len(messages) > 0 {
|
||||
role := messages[len(messages)-1].Get("role").String()
|
||||
content := messages[len(messages)-1].Get("content").String()
|
||||
if role != "user" {
|
||||
return types.ActionContinue
|
||||
}
|
||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||
if content != "" {
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
@@ -123,7 +192,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": "llm_query_moderation",
|
||||
"Service": config.requestCheckService,
|
||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
@@ -132,31 +201,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
respData := gjson.GetBytes(responseBody, "Data")
|
||||
if respData.Exists() {
|
||||
respAdvice := respData.Get("Advice")
|
||||
respResult := respData.Get("Result")
|
||||
if respAdvice.Exists() {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Request",
|
||||
Message: respAdvice.Array()[0].Get("Answer").String(),
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
|
||||
config.incrementCounter("ai_sec_request_deny", 1)
|
||||
if config.denyMessage != "" {
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1)
|
||||
} else {
|
||||
if gjson.GetBytes(body, "stream").Bool() {
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
label := respResult.Array()[0].Get("Label").String()
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(label))
|
||||
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.label."+label, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Request",
|
||||
Message: "risk detected",
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.risk_detected", [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
@@ -206,9 +271,16 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
messages := gjson.GetBytes(body, "choices").Array()
|
||||
if len(messages) > 0 {
|
||||
content := messages[0].Get("message").Get("content").String()
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
|
||||
var content string
|
||||
if isStreamingResponse {
|
||||
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
||||
} else {
|
||||
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
|
||||
}
|
||||
log.Debugf("Raw response content is: %s", content)
|
||||
if len(content) > 0 {
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
@@ -220,7 +292,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": "llm_response_moderation",
|
||||
"Service": config.responseCheckService,
|
||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
@@ -229,7 +301,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
defer proxywasm.ResumeHttpResponse()
|
||||
respData := gjson.GetBytes(responseBody, "Data")
|
||||
@@ -237,31 +309,23 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
respAdvice := respData.Get("Advice")
|
||||
respResult := respData.Get("Result")
|
||||
if respAdvice.Exists() {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Response",
|
||||
Message: respAdvice.Array()[0].Get("Answer").String(),
|
||||
var jsonData []byte
|
||||
if config.denyMessage != "" {
|
||||
jsonData = []byte(config.denyMessage)
|
||||
} else {
|
||||
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||
} else {
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||
}
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{"403"}
|
||||
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Response",
|
||||
Message: "risk detected",
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{"403"}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
|
||||
config.incrementCounter("ai_sec_response_deny", 1)
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -271,3 +335,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
strChunks := []string{}
|
||||
for _, chunk := range chunks {
|
||||
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
strChunks = append(strChunks, jsonObj.String())
|
||||
}
|
||||
}
|
||||
return strings.Join(strChunks, "")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user