mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
Update ai security guard (#1261)
This commit is contained in:
@@ -1,22 +1,143 @@
|
|||||||
|
---
|
||||||
|
title: AI内容安全
|
||||||
|
keywords: [higress, AI, security]
|
||||||
|
description: 阿里云内容安全检测
|
||||||
|
---
|
||||||
|
|
||||||
## 功能说明
|
## 功能说明
|
||||||
|
通过对接阿里云内容安全检测大模型的输入输出,保障AI应用内容合法合规。
|
||||||
|
|
||||||
|
## 运行属性
|
||||||
|
|
||||||
|
插件执行阶段:`默认阶段`
|
||||||
|
插件执行优先级:`300`
|
||||||
|
|
||||||
## 配置说明
|
## 配置说明
|
||||||
| Name | Type | Requirement | Default | Description |
|
| Name | Type | Requirement | Default | Description |
|
||||||
| :-: | :-: | :-: | :-: | :-: |
|
| ------------ | ------------ | ------------ | ------------ | ------------ |
|
||||||
| serviceSource | string | requried | - | 服务来源,填dns |
|
| `serviceName` | string | requried | - | 服务名 |
|
||||||
| serviceName | string | requried | - | 服务名 |
|
| `servicePort` | string | requried | - | 服务端口 |
|
||||||
| servicePort | string | requried | - | 服务端口 |
|
| `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 |
|
||||||
| domain | string | requried | - | 阿里云内容安全endpoint |
|
| `accessKey` | string | requried | - | 阿里云AK |
|
||||||
| ak | string | requried | - | 阿里云AK |
|
| `secretKey` | string | requried | - | 阿里云SK |
|
||||||
| sk | string | requried | - | 阿里云SK |
|
| `checkRequest` | bool | optional | false | 检查提问内容是否合规 |
|
||||||
|
| `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 |
|
||||||
|
| `requestCheckService` | string | optional | llm_query_moderation | 指定阿里云内容安全用于检测输入内容的服务 |
|
||||||
|
| `responseCheckService` | string | optional | llm_response_moderation | 指定阿里云内容安全用于检测输出内容的服务 |
|
||||||
|
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath |
|
||||||
|
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath |
|
||||||
|
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath |
|
||||||
|
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
|
||||||
|
| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 |
|
||||||
|
|
||||||
|
|
||||||
## 配置示例
|
## 配置示例
|
||||||
|
### 前提条件
|
||||||
|
由于插件中需要调用阿里云内容安全服务,所以需要先创建一个DNS类型的服务,例如:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### 检测输入内容是否合规
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
serviceSource: "dns"
|
serviceName: safecheck.dns
|
||||||
serviceName: "safecheck"
|
|
||||||
servicePort: 443
|
servicePort: 443
|
||||||
domain: "green-cip.cn-shanghai.aliyuncs.com"
|
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||||
ak: "XXXXXXXXX"
|
accessKey: "XXXXXXXXX"
|
||||||
sk: "XXXXXXXXXXXXXXX"
|
secretKey: "XXXXXXXXXXXXXXX"
|
||||||
|
checkRequest: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### 检测输入与输出是否合规
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
serviceName: safecheck.dns
|
||||||
|
servicePort: 443
|
||||||
|
serviceHost: green-cip.cn-shanghai.aliyuncs.com
|
||||||
|
accessKey: "XXXXXXXXX"
|
||||||
|
secretKey: "XXXXXXXXXXXXXXX"
|
||||||
|
checkRequest: true
|
||||||
|
checkResponse: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### 指定自定义内容安全检测服务
|
||||||
|
用户可能需要根据不同的场景配置不同的检测规则,该问题可通过为不同域名/路由/服务配置不同的内容安全检测服务实现。如下图所示,我们创建了一个名为 llm_query_moderation_01 的检测服务,其中的检测规则在 llm_query_moderation 之上做了一些改动:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
接下来在目标域名/路由/服务级别进行以下配置,指定使用我们自定义的 llm_query_moderation_01 中的规则进行检测:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
serviceName: safecheck.dns
|
||||||
|
servicePort: 443
|
||||||
|
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||||
|
accessKey: "XXXXXXXXX"
|
||||||
|
secretKey: "XXXXXXXXXXXXXXX"
|
||||||
|
checkRequest: true
|
||||||
|
requestCheckService: llm_query_moderation_01
|
||||||
|
```
|
||||||
|
|
||||||
|
### 配置非openai协议(例如百炼App)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
serviceName: safecheck.dns
|
||||||
|
servicePort: 443
|
||||||
|
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||||
|
accessKey: "XXXXXXXXX"
|
||||||
|
secretKey: "XXXXXXXXXXXXXXX"
|
||||||
|
checkRequest: true
|
||||||
|
checkResponse: true
|
||||||
|
requestContentJsonPath: "input.prompt"
|
||||||
|
responseContentJsonPath: "output.text"
|
||||||
|
denyCode: 200
|
||||||
|
denyMessage: "很抱歉,我无法回答您的问题"
|
||||||
|
```
|
||||||
|
|
||||||
|
## 可观测
|
||||||
|
### Metric
|
||||||
|
ai-security-guard 插件提供了以下监控指标:
|
||||||
|
- `ai_sec_request_deny`: 请求内容安全检测失败请求数
|
||||||
|
- `ai_sec_response_deny`: 模型回答安全检测失败请求数
|
||||||
|
|
||||||
|
### Trace
|
||||||
|
如果开启了链路追踪,ai-security-guard 插件会在请求 span 中添加以下 attributes:
|
||||||
|
- `ai_sec_risklabel`: 表示请求命中的风险类型
|
||||||
|
- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段(取值为request或者response)
|
||||||
|
|
||||||
|
## 请求示例
|
||||||
|
```bash
|
||||||
|
curl http://localhost/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "这是一段非法内容"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"system_fingerprint": "fp_44709d6fcb",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
|
||||||
|
},
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|||||||
69
plugins/wasm-go/extensions/ai-security-guard/README_EN.md
Normal file
69
plugins/wasm-go/extensions/ai-security-guard/README_EN.md
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
---
|
||||||
|
title: AI Content Security
|
||||||
|
keywords: [higress, AI, security]
|
||||||
|
description: Alibaba Cloud content security
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
Integrate with Aliyun content security service for detections of input and output of LLMs, ensuring that application content is legal and compliant.
|
||||||
|
|
||||||
|
## Runtime Properties
|
||||||
|
|
||||||
|
Plugin Phase: `CUSTOM`
|
||||||
|
Plugin Priority: `300`
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
| Name | Type | Requirement | Default | Description |
|
||||||
|
| ------------ | ------------ | ------------ | ------------ | ------------ |
|
||||||
|
| `serviceName` | string | requried | - | service name |
|
||||||
|
| `servicePort` | string | requried | - | service port |
|
||||||
|
| `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint |
|
||||||
|
| `accessKey` | string | requried | - | Aliyun accesskey |
|
||||||
|
| `secretKey` | string | requried | - | Aliyun secretkey |
|
||||||
|
| `checkRequest` | bool | optional | false | check if the input is legal |
|
||||||
|
| `checkResponse` | bool | optional | false | check if the output is legal |
|
||||||
|
| `requestCheckService` | string | optional | llm_query_moderation | Aliyun yundun service name for input check |
|
||||||
|
| `responseCheckService` | string | optional | llm_response_moderation | Aliyun yundun service name for output check |
|
||||||
|
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body |
|
||||||
|
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body |
|
||||||
|
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body |
|
||||||
|
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
|
||||||
|
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security
|
||||||
|
| Response content when the specified content is illegal |
|
||||||
|
|
||||||
|
|
||||||
|
## Examples of configuration
|
||||||
|
### Check if the input is legal
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
serviceName: safecheck.dns
|
||||||
|
servicePort: 443
|
||||||
|
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
|
||||||
|
accessKey: "XXXXXXXXX"
|
||||||
|
secretKey: "XXXXXXXXXXXXXXX"
|
||||||
|
checkRequest: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check if both the input and output are legal
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
serviceName: safecheck.dns
|
||||||
|
servicePort: 443
|
||||||
|
serviceHost: green-cip.cn-shanghai.aliyuncs.com
|
||||||
|
accessKey: "XXXXXXXXX"
|
||||||
|
secretKey: "XXXXXXXXXXXXXXX"
|
||||||
|
checkRequest: true
|
||||||
|
checkResponse: true
|
||||||
|
```
|
||||||
|
|
||||||
|
## Observability
|
||||||
|
### Metric
|
||||||
|
ai-security-guard plugin provides following metrics:
|
||||||
|
- `ai_sec_request_deny`: count of requests denied at request phase
|
||||||
|
- `ai_sec_response_deny`: count of requests denied at response phase
|
||||||
|
|
||||||
|
### Trace
|
||||||
|
ai-security-guard plugin provides following span attributes:
|
||||||
|
- `ai_sec_risklabel`: risk type of this request
|
||||||
|
- `ai_sec_deny_phase`: denied phase of this request, value can be request/response
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
module myplugin
|
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard
|
||||||
|
|
||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,9 @@
|
|||||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
|
||||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
|
||||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
|
|
||||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -32,16 +32,47 @@ func main() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}`
|
||||||
|
OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||||
|
OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
|
||||||
|
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd
|
||||||
|
|
||||||
|
TracingPrefix = "trace_span_tag."
|
||||||
|
|
||||||
|
DefaultRequestCheckService = "llm_query_moderation"
|
||||||
|
DefaultResponseCheckService = "llm_response_moderation"
|
||||||
|
DefaultRequestJsonPath = "messages.@reverse.0.content"
|
||||||
|
DefaultResponseJsonPath = "choices.0.message.content"
|
||||||
|
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
|
||||||
|
DefaultDenyCode = 200
|
||||||
|
|
||||||
|
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||||
|
)
|
||||||
|
|
||||||
type AISecurityConfig struct {
|
type AISecurityConfig struct {
|
||||||
client wrapper.HttpClient
|
client wrapper.HttpClient
|
||||||
ak string
|
ak string
|
||||||
sk string
|
sk string
|
||||||
|
checkRequest bool
|
||||||
|
requestCheckService string
|
||||||
|
requestContentJsonPath string
|
||||||
|
checkResponse bool
|
||||||
|
responseCheckService string
|
||||||
|
responseContentJsonPath string
|
||||||
|
responseStreamContentJsonPath string
|
||||||
|
denyCode int64
|
||||||
|
denyMessage string
|
||||||
|
metrics map[string]proxywasm.MetricCounter
|
||||||
}
|
}
|
||||||
|
|
||||||
type StandardResponse struct {
|
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
|
||||||
Code int `json:"Code"`
|
counter, ok := config.metrics[metricName]
|
||||||
Phase string `json:"BlockPhase"`
|
if !ok {
|
||||||
Message string `json:"Message"`
|
counter = proxywasm.DefineCounterMetric(metricName)
|
||||||
|
config.metrics[metricName] = counter
|
||||||
|
}
|
||||||
|
counter.Increment(inc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func urlEncoding(rawStr string) string {
|
func urlEncoding(rawStr string) string {
|
||||||
@@ -71,7 +102,7 @@ func getSign(params map[string]string, secret string) string {
|
|||||||
})
|
})
|
||||||
canonicalStr := strings.Join(paramArray, "&")
|
canonicalStr := strings.Join(paramArray, "&")
|
||||||
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
||||||
fmt.Println(signStr)
|
// proxywasm.LogInfo(signStr)
|
||||||
return hmacSha1(signStr, secret)
|
return hmacSha1(signStr, secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,32 +117,70 @@ func generateHexID(length int) (string, error) {
|
|||||||
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
|
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
|
||||||
serviceName := json.Get("serviceName").String()
|
serviceName := json.Get("serviceName").String()
|
||||||
servicePort := json.Get("servicePort").Int()
|
servicePort := json.Get("servicePort").Int()
|
||||||
domain := json.Get("domain").String()
|
serviceHost := json.Get("serviceHost").String()
|
||||||
config.ak = json.Get("ak").String()
|
if serviceName == "" || servicePort == 0 || serviceHost == "" {
|
||||||
config.sk = json.Get("sk").String()
|
|
||||||
if serviceName == "" || servicePort == 0 || domain == "" {
|
|
||||||
return errors.New("invalid service config")
|
return errors.New("invalid service config")
|
||||||
}
|
}
|
||||||
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
config.ak = json.Get("accessKey").String()
|
||||||
ServiceName: serviceName,
|
config.sk = json.Get("secretKey").String()
|
||||||
Port: servicePort,
|
if config.ak == "" || config.sk == "" {
|
||||||
Domain: domain,
|
return errors.New("invalid AK/SK config")
|
||||||
|
}
|
||||||
|
config.checkRequest = json.Get("checkRequest").Bool()
|
||||||
|
config.checkResponse = json.Get("checkResponse").Bool()
|
||||||
|
config.denyMessage = json.Get("denyMessage").String()
|
||||||
|
if obj := json.Get("denyCode"); obj.Exists() {
|
||||||
|
config.denyCode = obj.Int()
|
||||||
|
} else {
|
||||||
|
config.denyCode = DefaultDenyCode
|
||||||
|
}
|
||||||
|
if obj := json.Get("requestCheckService"); obj.Exists() {
|
||||||
|
config.requestCheckService = obj.String()
|
||||||
|
} else {
|
||||||
|
config.requestCheckService = DefaultRequestCheckService
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseCheckService"); obj.Exists() {
|
||||||
|
config.responseCheckService = obj.String()
|
||||||
|
} else {
|
||||||
|
config.responseCheckService = DefaultResponseCheckService
|
||||||
|
}
|
||||||
|
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
|
||||||
|
config.requestContentJsonPath = obj.String()
|
||||||
|
} else {
|
||||||
|
config.requestContentJsonPath = DefaultRequestJsonPath
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
|
||||||
|
config.responseContentJsonPath = obj.String()
|
||||||
|
} else {
|
||||||
|
config.responseContentJsonPath = DefaultResponseJsonPath
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
|
||||||
|
config.responseStreamContentJsonPath = obj.String()
|
||||||
|
} else {
|
||||||
|
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
||||||
|
}
|
||||||
|
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: serviceName,
|
||||||
|
Port: servicePort,
|
||||||
|
Host: serviceHost,
|
||||||
})
|
})
|
||||||
|
config.metrics = make(map[string]proxywasm.MetricCounter)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||||
|
if !config.checkRequest {
|
||||||
|
ctx.DontReadRequestBody()
|
||||||
|
}
|
||||||
|
if !config.checkResponse {
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
}
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||||
messages := gjson.GetBytes(body, "messages").Array()
|
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||||
if len(messages) > 0 {
|
if content != "" {
|
||||||
role := messages[len(messages)-1].Get("role").String()
|
|
||||||
content := messages[len(messages)-1].Get("content").String()
|
|
||||||
if role != "user" {
|
|
||||||
return types.ActionContinue
|
|
||||||
}
|
|
||||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||||
randomID, _ := generateHexID(16)
|
randomID, _ := generateHexID(16)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
@@ -123,7 +192,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|||||||
"Action": "TextModerationPlus",
|
"Action": "TextModerationPlus",
|
||||||
"AccessKeyId": config.ak,
|
"AccessKeyId": config.ak,
|
||||||
"Timestamp": timestamp,
|
"Timestamp": timestamp,
|
||||||
"Service": "llm_query_moderation",
|
"Service": config.requestCheckService,
|
||||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||||
}
|
}
|
||||||
signature := getSign(params, config.sk+"&")
|
signature := getSign(params, config.sk+"&")
|
||||||
@@ -132,31 +201,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|||||||
reqParams.Add(k, v)
|
reqParams.Add(k, v)
|
||||||
}
|
}
|
||||||
reqParams.Add("Signature", signature)
|
reqParams.Add("Signature", signature)
|
||||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
respData := gjson.GetBytes(responseBody, "Data")
|
respData := gjson.GetBytes(responseBody, "Data")
|
||||||
if respData.Exists() {
|
if respData.Exists() {
|
||||||
respAdvice := respData.Get("Advice")
|
respAdvice := respData.Get("Advice")
|
||||||
respResult := respData.Get("Result")
|
respResult := respData.Get("Result")
|
||||||
if respAdvice.Exists() {
|
if respAdvice.Exists() {
|
||||||
sr := StandardResponse{
|
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||||
Code: 403,
|
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
|
||||||
Phase: "Request",
|
config.incrementCounter("ai_sec_request_deny", 1)
|
||||||
Message: respAdvice.Array()[0].Get("Answer").String(),
|
if config.denyMessage != "" {
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1)
|
||||||
|
} else {
|
||||||
|
if gjson.GetBytes(body, "stream").Bool() {
|
||||||
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
|
} else {
|
||||||
|
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
|
||||||
label := respResult.Array()[0].Get("Label").String()
|
|
||||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(label))
|
|
||||||
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.label."+label, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
|
||||||
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
|
||||||
sr := StandardResponse{
|
|
||||||
Code: 403,
|
|
||||||
Phase: "Request",
|
|
||||||
Message: "risk detected",
|
|
||||||
}
|
|
||||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
|
||||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
|
||||||
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.risk_detected", [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
|
||||||
} else {
|
} else {
|
||||||
proxywasm.ResumeHttpRequest()
|
proxywasm.ResumeHttpRequest()
|
||||||
}
|
}
|
||||||
@@ -206,9 +271,16 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
|||||||
}
|
}
|
||||||
|
|
||||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||||
messages := gjson.GetBytes(body, "choices").Array()
|
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||||
if len(messages) > 0 {
|
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
|
||||||
content := messages[0].Get("message").Get("content").String()
|
var content string
|
||||||
|
if isStreamingResponse {
|
||||||
|
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
||||||
|
} else {
|
||||||
|
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
|
||||||
|
}
|
||||||
|
log.Debugf("Raw response content is: %s", content)
|
||||||
|
if len(content) > 0 {
|
||||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||||
randomID, _ := generateHexID(16)
|
randomID, _ := generateHexID(16)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
@@ -220,7 +292,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
"Action": "TextModerationPlus",
|
"Action": "TextModerationPlus",
|
||||||
"AccessKeyId": config.ak,
|
"AccessKeyId": config.ak,
|
||||||
"Timestamp": timestamp,
|
"Timestamp": timestamp,
|
||||||
"Service": "llm_response_moderation",
|
"Service": config.responseCheckService,
|
||||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||||
}
|
}
|
||||||
signature := getSign(params, config.sk+"&")
|
signature := getSign(params, config.sk+"&")
|
||||||
@@ -229,7 +301,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
reqParams.Add(k, v)
|
reqParams.Add(k, v)
|
||||||
}
|
}
|
||||||
reqParams.Add("Signature", signature)
|
reqParams.Add("Signature", signature)
|
||||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
|
||||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
defer proxywasm.ResumeHttpResponse()
|
defer proxywasm.ResumeHttpResponse()
|
||||||
respData := gjson.GetBytes(responseBody, "Data")
|
respData := gjson.GetBytes(responseBody, "Data")
|
||||||
@@ -237,31 +309,23 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
respAdvice := respData.Get("Advice")
|
respAdvice := respData.Get("Advice")
|
||||||
respResult := respData.Get("Result")
|
respResult := respData.Get("Result")
|
||||||
if respAdvice.Exists() {
|
if respAdvice.Exists() {
|
||||||
sr := StandardResponse{
|
var jsonData []byte
|
||||||
Code: 403,
|
if config.denyMessage != "" {
|
||||||
Phase: "Response",
|
jsonData = []byte(config.denyMessage)
|
||||||
Message: respAdvice.Array()[0].Get("Answer").String(),
|
} else {
|
||||||
|
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
|
||||||
|
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||||
|
} else {
|
||||||
|
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
|
||||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
|
||||||
delete(hdsMap, "content-length")
|
delete(hdsMap, "content-length")
|
||||||
hdsMap[":status"] = []string{"403"}
|
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
||||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||||
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
|
||||||
sr := StandardResponse{
|
config.incrementCounter("ai_sec_response_deny", 1)
|
||||||
Code: 403,
|
|
||||||
Phase: "Response",
|
|
||||||
Message: "risk detected",
|
|
||||||
}
|
|
||||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
|
||||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
|
||||||
delete(hdsMap, "content-length")
|
|
||||||
hdsMap[":status"] = []string{"403"}
|
|
||||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
|
||||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
|
||||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -271,3 +335,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||||
|
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||||
|
strChunks := []string{}
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||||
|
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||||
|
if jsonObj.Exists() {
|
||||||
|
strChunks = append(strChunks, jsonObj.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(strChunks, "")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user