Update ai security guard (#1261)

This commit is contained in:
rinfx
2024-09-24 19:42:34 +08:00
committed by GitHub
parent b82853c653
commit e004321cb0
5 changed files with 350 additions and 88 deletions

View File

@@ -1,22 +1,143 @@
---
title: AI内容安全
keywords: [higress, AI, security]
description: 阿里云内容安全检测
---
## 功能说明
通过对接阿里云内容安全检测大模型的输入输出保障AI应用内容合法合规。
## 运行属性
插件执行阶段:`默认阶段`
插件执行优先级:`300`
## 配置说明
| Name | Type | Requirement | Default | Description |
| :-: | :-: | :-: | :-: | :-: |
| serviceSource | string | requried | - | 服务来源填dns |
| serviceName | string | requried | - | 服务 |
| servicePort | string | requried | - | 服务端口 |
| domain | string | requried | - | 阿里云内容安全endpoint |
| ak | string | requried | - | 阿里云AK |
| sk | string | requried | - | 阿里云SK |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceName` | string | requried | - | 服务 |
| `servicePort` | string | requried | - | 服务端口 |
| `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 |
| `accessKey` | string | requried | - | 阿里云AK |
| `secretKey` | string | requried | - | 阿里云SK |
| `checkRequest` | bool | optional | false | 检查提问内容是否合规 |
| `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 |
| `requestCheckService` | string | optional | llm_query_moderation | 指定阿里云内容安全用于检测输入内容的服务 |
| `responseCheckService` | string | optional | llm_response_moderation | 指定阿里云内容安全用于检测输出内容的服务 |
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath |
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath |
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 |
## 配置示例
### 前提条件
由于插件中需要调用阿里云内容安全服务所以需要先创建一个DNS类型的服务例如
![](https://img.alicdn.com/imgextra/i4/O1CN013AbDcn1slCY19inU2_!!6000000005806-0-tps-1754-1320.jpg)
### 检测输入内容是否合规
```yaml
serviceSource: "dns"
serviceName: "safecheck"
serviceName: safecheck.dns
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
```
### 检测输入与输出是否合规
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
```
### 指定自定义内容安全检测服务
用户可能需要根据不同的场景配置不同的检测规则,该问题可通过为不同域名/路由/服务配置不同的内容安全检测服务实现。如下图所示,我们创建了一个名为 llm_query_moderation_01 的检测服务,其中的检测规则在 llm_query_moderation 之上做了一些改动:
![](https://img.alicdn.com/imgextra/i4/O1CN01bAtcvn1N9sB16iiZR_!!6000000001528-0-tps-2728-822.jpg)
接下来在目标域名/路由/服务级别进行以下配置,指定使用我们自定义的 llm_query_moderation_01 中的规则进行检测:
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
requestCheckService: llm_query_moderation_01
```
### 配置非openai协议例如百炼App
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
requestContentJsonPath: "input.prompt"
responseContentJsonPath: "output.text"
denyCode: 200
denyMessage: "很抱歉,我无法回答您的问题"
```
## 可观测
### Metric
ai-security-guard 插件提供了以下监控指标:
- `ai_sec_request_deny`: 请求内容安全检测失败请求数
- `ai_sec_response_deny`: 模型回答安全检测失败请求数
### Trace
如果开启了链路追踪ai-security-guard 插件会在请求 span 中添加以下 attributes:
- `ai_sec_risklabel`: 表示请求命中的风险类型
- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段取值为request或者response
## 请求示例
```bash
curl http://localhost/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": "这是一段非法内容"
}
]
}'
```
请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答:
```json
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o-mini",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
},
"logprobs": null,
"finish_reason": "stop"
}
]
}
```

View File

@@ -0,0 +1,69 @@
---
title: AI Content Security
keywords: [higress, AI, security]
description: Alibaba Cloud content security
---
## Introduction
Integrate with Aliyun content security service for detections of input and output of LLMs, ensuring that application content is legal and compliant.
## Runtime Properties
Plugin Phase: `CUSTOM`
Plugin Priority: `300`
## Configuration
| Name | Type | Requirement | Default | Description |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceName` | string | requried | - | service name |
| `servicePort` | string | requried | - | service port |
| `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint |
| `accessKey` | string | requried | - | Aliyun accesskey |
| `secretKey` | string | requried | - | Aliyun secretkey |
| `checkRequest` | bool | optional | false | check if the input is legal |
| `checkResponse` | bool | optional | false | check if the output is legal |
| `requestCheckService` | string | optional | llm_query_moderation | Aliyun yundun service name for input check |
| `responseCheckService` | string | optional | llm_response_moderation | Aliyun yundun service name for output check |
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body |
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body |
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security
| Response content when the specified content is illegal |
## Examples of configuration
### Check if the input is legal
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
```
### Check if both the input and output are legal
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
```
## Observability
### Metric
ai-security-guard plugin provides following metrics:
- `ai_sec_request_deny`: count of requests denied at request phase
- `ai_sec_response_deny`: count of requests denied at response phase
### Trace
ai-security-guard plugin provides following span attributes:
- `ai_sec_risklabel`: risk type of this request
- `ai_sec_deny_phase`: denied phase of this request, value can be request/response

View File

@@ -1,4 +1,4 @@
module myplugin
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard
go 1.18

View File

@@ -1,14 +1,9 @@
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=

View File

@@ -1,12 +1,12 @@
package main
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -32,16 +32,47 @@ func main() {
)
}
const (
OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd
TracingPrefix = "trace_span_tag."
DefaultRequestCheckService = "llm_query_moderation"
DefaultResponseCheckService = "llm_response_moderation"
DefaultRequestJsonPath = "messages.@reverse.0.content"
DefaultResponseJsonPath = "choices.0.message.content"
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
DefaultDenyCode = 200
AliyunUserAgent = "CIPFrom/AIGateway"
)
type AISecurityConfig struct {
client wrapper.HttpClient
ak string
sk string
checkRequest bool
requestCheckService string
requestContentJsonPath string
checkResponse bool
responseCheckService string
responseContentJsonPath string
responseStreamContentJsonPath string
denyCode int64
denyMessage string
metrics map[string]proxywasm.MetricCounter
}
type StandardResponse struct {
Code int `json:"Code"`
Phase string `json:"BlockPhase"`
Message string `json:"Message"`
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
counter, ok := config.metrics[metricName]
if !ok {
counter = proxywasm.DefineCounterMetric(metricName)
config.metrics[metricName] = counter
}
counter.Increment(inc)
}
func urlEncoding(rawStr string) string {
@@ -71,7 +102,7 @@ func getSign(params map[string]string, secret string) string {
})
canonicalStr := strings.Join(paramArray, "&")
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
fmt.Println(signStr)
// proxywasm.LogInfo(signStr)
return hmacSha1(signStr, secret)
}
@@ -86,32 +117,70 @@ func generateHexID(length int) (string, error) {
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
serviceName := json.Get("serviceName").String()
servicePort := json.Get("servicePort").Int()
domain := json.Get("domain").String()
config.ak = json.Get("ak").String()
config.sk = json.Get("sk").String()
if serviceName == "" || servicePort == 0 || domain == "" {
serviceHost := json.Get("serviceHost").String()
if serviceName == "" || servicePort == 0 || serviceHost == "" {
return errors.New("invalid service config")
}
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: serviceName,
config.ak = json.Get("accessKey").String()
config.sk = json.Get("secretKey").String()
if config.ak == "" || config.sk == "" {
return errors.New("invalid AK/SK config")
}
config.checkRequest = json.Get("checkRequest").Bool()
config.checkResponse = json.Get("checkResponse").Bool()
config.denyMessage = json.Get("denyMessage").String()
if obj := json.Get("denyCode"); obj.Exists() {
config.denyCode = obj.Int()
} else {
config.denyCode = DefaultDenyCode
}
if obj := json.Get("requestCheckService"); obj.Exists() {
config.requestCheckService = obj.String()
} else {
config.requestCheckService = DefaultRequestCheckService
}
if obj := json.Get("responseCheckService"); obj.Exists() {
config.responseCheckService = obj.String()
} else {
config.responseCheckService = DefaultResponseCheckService
}
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
config.requestContentJsonPath = obj.String()
} else {
config.requestContentJsonPath = DefaultRequestJsonPath
}
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
config.responseContentJsonPath = obj.String()
} else {
config.responseContentJsonPath = DefaultResponseJsonPath
}
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
config.responseStreamContentJsonPath = obj.String()
} else {
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
Domain: domain,
Host: serviceHost,
})
config.metrics = make(map[string]proxywasm.MetricCounter)
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkRequest {
ctx.DontReadRequestBody()
}
if !config.checkResponse {
ctx.DontReadResponseBody()
}
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
messages := gjson.GetBytes(body, "messages").Array()
if len(messages) > 0 {
role := messages[len(messages)-1].Get("role").String()
content := messages[len(messages)-1].Get("content").String()
if role != "user" {
return types.ActionContinue
}
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
if content != "" {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
@@ -123,7 +192,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": "llm_query_moderation",
"Service": config.requestCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
}
signature := getSign(params, config.sk+"&")
@@ -132,31 +201,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
sr := StandardResponse{
Code: 403,
Phase: "Request",
Message: respAdvice.Array()[0].Get("Answer").String(),
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
config.incrementCounter("ai_sec_request_deny", 1)
if config.denyMessage != "" {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1)
} else {
if gjson.GetBytes(body, "stream").Bool() {
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
label := respResult.Array()[0].Get("Label").String()
proxywasm.SetProperty([]string{"risklabel"}, []byte(label))
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.label."+label, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
sr := StandardResponse{
Code: 403,
Phase: "Request",
Message: "risk detected",
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.risk_detected", [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} else {
proxywasm.ResumeHttpRequest()
}
@@ -206,9 +271,16 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
messages := gjson.GetBytes(body, "choices").Array()
if len(messages) > 0 {
content := messages[0].Get("message").Get("content").String()
hdsMap := ctx.GetContext("headers").(map[string][]string)
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
var content string
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
} else {
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
}
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
@@ -220,7 +292,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": "llm_response_moderation",
"Service": config.responseCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
}
signature := getSign(params, config.sk+"&")
@@ -229,7 +301,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
respData := gjson.GetBytes(responseBody, "Data")
@@ -237,31 +309,23 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
sr := StandardResponse{
Code: 403,
Phase: "Response",
Message: respAdvice.Array()[0].Get("Answer").String(),
var jsonData []byte
if config.denyMessage != "" {
jsonData = []byte(config.denyMessage)
} else {
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
} else {
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
}
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
hdsMap := ctx.GetContext("headers").(map[string][]string)
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"403"}
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
sr := StandardResponse{
Code: 403,
Phase: "Response",
Message: "risk detected",
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
hdsMap := ctx.GetContext("headers").(map[string][]string)
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"403"}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
config.incrementCounter("ai_sec_response_deny", 1)
}
}
},
@@ -271,3 +335,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
return types.ActionContinue
}
}
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
strChunks = append(strChunks, jsonObj.String())
}
}
return strings.Join(strChunks, "")
}