Update ai security guard (#1261)

This commit is contained in:
rinfx
2024-09-24 19:42:34 +08:00
committed by GitHub
parent b82853c653
commit e004321cb0
5 changed files with 350 additions and 88 deletions

View File

@@ -1,22 +1,143 @@
---
title: AI内容安全
keywords: [higress, AI, security]
description: 阿里云内容安全检测
---
## 功能说明
通过对接阿里云内容安全检测大模型的输入输出保障AI应用内容合法合规。
## 运行属性
插件执行阶段:`默认阶段`
插件执行优先级:`300`
## 配置说明
| Name | Type | Requirement | Default | Description |
| :-: | :-: | :-: | :-: | :-: |
| serviceSource | string | requried | - | 服务来源填dns |
| serviceName | string | requried | - | 服务 |
| servicePort | string | requried | - | 服务端口 |
| domain | string | requried | - | 阿里云内容安全endpoint |
| ak | string | requried | - | 阿里云AK |
| sk | string | requried | - | 阿里云SK |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceName` | string | requried | - | 服务 |
| `servicePort` | string | requried | - | 服务端口 |
| `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 |
| `accessKey` | string | requried | - | 阿里云AK |
| `secretKey` | string | requried | - | 阿里云SK |
| `checkRequest` | bool | optional | false | 检查提问内容是否合规 |
| `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 |
| `requestCheckService` | string | optional | llm_query_moderation | 指定阿里云内容安全用于检测输入内容的服务 |
| `responseCheckService` | string | optional | llm_response_moderation | 指定阿里云内容安全用于检测输出内容的服务 |
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath |
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath |
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 |
## 配置示例
### 前提条件
由于插件中需要调用阿里云内容安全服务所以需要先创建一个DNS类型的服务例如
![](https://img.alicdn.com/imgextra/i4/O1CN013AbDcn1slCY19inU2_!!6000000005806-0-tps-1754-1320.jpg)
### 检测输入内容是否合规
```yaml
serviceSource: "dns"
serviceName: "safecheck"
serviceName: safecheck.dns
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
```
### 检测输入与输出是否合规
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
```
### 指定自定义内容安全检测服务
用户可能需要根据不同的场景配置不同的检测规则,该问题可通过为不同域名/路由/服务配置不同的内容安全检测服务实现。如下图所示,我们创建了一个名为 llm_query_moderation_01 的检测服务,其中的检测规则在 llm_query_moderation 之上做了一些改动:
![](https://img.alicdn.com/imgextra/i4/O1CN01bAtcvn1N9sB16iiZR_!!6000000001528-0-tps-2728-822.jpg)
接下来在目标域名/路由/服务级别进行以下配置,指定使用我们自定义的 llm_query_moderation_01 中的规则进行检测:
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
requestCheckService: llm_query_moderation_01
```
### 配置非openai协议例如百炼App
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
requestContentJsonPath: "input.prompt"
responseContentJsonPath: "output.text"
denyCode: 200
denyMessage: "很抱歉,我无法回答您的问题"
```
## 可观测
### Metric
ai-security-guard 插件提供了以下监控指标:
- `ai_sec_request_deny`: 请求内容安全检测失败请求数
- `ai_sec_response_deny`: 模型回答安全检测失败请求数
### Trace
如果开启了链路追踪ai-security-guard 插件会在请求 span 中添加以下 attributes:
- `ai_sec_risklabel`: 表示请求命中的风险类型
- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段取值为request或者response
## 请求示例
```bash
curl http://localhost/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": "这是一段非法内容"
}
]
}'
```
请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答:
```json
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o-mini",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
},
"logprobs": null,
"finish_reason": "stop"
}
]
}
```

View File

@@ -0,0 +1,69 @@
---
title: AI Content Security
keywords: [higress, AI, security]
description: Alibaba Cloud content security
---
## Introduction
Integrate with Aliyun content security service for detections of input and output of LLMs, ensuring that application content is legal and compliant.
## Runtime Properties
Plugin Phase: `CUSTOM`
Plugin Priority: `300`
## Configuration
| Name | Type | Requirement | Default | Description |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceName` | string | requried | - | service name |
| `servicePort` | string | requried | - | service port |
| `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint |
| `accessKey` | string | requried | - | Aliyun accesskey |
| `secretKey` | string | requried | - | Aliyun secretkey |
| `checkRequest` | bool | optional | false | check if the input is legal |
| `checkResponse` | bool | optional | false | check if the output is legal |
| `requestCheckService` | string | optional | llm_query_moderation | Aliyun yundun service name for input check |
| `responseCheckService` | string | optional | llm_response_moderation | Aliyun yundun service name for output check |
| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body |
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body |
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security
| Response content when the specified content is illegal |
## Examples of configuration
### Check if the input is legal
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
```
### Check if both the input and output are legal
```yaml
serviceName: safecheck.dns
servicePort: 443
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
```
## Observability
### Metric
ai-security-guard plugin provides following metrics:
- `ai_sec_request_deny`: count of requests denied at request phase
- `ai_sec_response_deny`: count of requests denied at response phase
### Trace
ai-security-guard plugin provides following span attributes:
- `ai_sec_risklabel`: risk type of this request
- `ai_sec_deny_phase`: denied phase of this request, value can be request/response

View File

@@ -1,4 +1,4 @@
module myplugin
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard
go 1.18

View File

@@ -1,14 +1,9 @@
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=

View File

@@ -1,12 +1,12 @@
package main
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -32,16 +32,47 @@ func main() {
)
}
const (
OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd
TracingPrefix = "trace_span_tag."
DefaultRequestCheckService = "llm_query_moderation"
DefaultResponseCheckService = "llm_response_moderation"
DefaultRequestJsonPath = "messages.@reverse.0.content"
DefaultResponseJsonPath = "choices.0.message.content"
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
DefaultDenyCode = 200
AliyunUserAgent = "CIPFrom/AIGateway"
)
type AISecurityConfig struct {
client wrapper.HttpClient
ak string
sk string
client wrapper.HttpClient
ak string
sk string
checkRequest bool
requestCheckService string
requestContentJsonPath string
checkResponse bool
responseCheckService string
responseContentJsonPath string
responseStreamContentJsonPath string
denyCode int64
denyMessage string
metrics map[string]proxywasm.MetricCounter
}
type StandardResponse struct {
Code int `json:"Code"`
Phase string `json:"BlockPhase"`
Message string `json:"Message"`
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
counter, ok := config.metrics[metricName]
if !ok {
counter = proxywasm.DefineCounterMetric(metricName)
config.metrics[metricName] = counter
}
counter.Increment(inc)
}
func urlEncoding(rawStr string) string {
@@ -71,7 +102,7 @@ func getSign(params map[string]string, secret string) string {
})
canonicalStr := strings.Join(paramArray, "&")
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
fmt.Println(signStr)
// proxywasm.LogInfo(signStr)
return hmacSha1(signStr, secret)
}
@@ -86,32 +117,70 @@ func generateHexID(length int) (string, error) {
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
serviceName := json.Get("serviceName").String()
servicePort := json.Get("servicePort").Int()
domain := json.Get("domain").String()
config.ak = json.Get("ak").String()
config.sk = json.Get("sk").String()
if serviceName == "" || servicePort == 0 || domain == "" {
serviceHost := json.Get("serviceHost").String()
if serviceName == "" || servicePort == 0 || serviceHost == "" {
return errors.New("invalid service config")
}
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: serviceName,
Port: servicePort,
Domain: domain,
config.ak = json.Get("accessKey").String()
config.sk = json.Get("secretKey").String()
if config.ak == "" || config.sk == "" {
return errors.New("invalid AK/SK config")
}
config.checkRequest = json.Get("checkRequest").Bool()
config.checkResponse = json.Get("checkResponse").Bool()
config.denyMessage = json.Get("denyMessage").String()
if obj := json.Get("denyCode"); obj.Exists() {
config.denyCode = obj.Int()
} else {
config.denyCode = DefaultDenyCode
}
if obj := json.Get("requestCheckService"); obj.Exists() {
config.requestCheckService = obj.String()
} else {
config.requestCheckService = DefaultRequestCheckService
}
if obj := json.Get("responseCheckService"); obj.Exists() {
config.responseCheckService = obj.String()
} else {
config.responseCheckService = DefaultResponseCheckService
}
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
config.requestContentJsonPath = obj.String()
} else {
config.requestContentJsonPath = DefaultRequestJsonPath
}
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
config.responseContentJsonPath = obj.String()
} else {
config.responseContentJsonPath = DefaultResponseJsonPath
}
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
config.responseStreamContentJsonPath = obj.String()
} else {
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
Host: serviceHost,
})
config.metrics = make(map[string]proxywasm.MetricCounter)
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkRequest {
ctx.DontReadRequestBody()
}
if !config.checkResponse {
ctx.DontReadResponseBody()
}
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
messages := gjson.GetBytes(body, "messages").Array()
if len(messages) > 0 {
role := messages[len(messages)-1].Get("role").String()
content := messages[len(messages)-1].Get("content").String()
if role != "user" {
return types.ActionContinue
}
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
if content != "" {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
@@ -123,7 +192,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": "llm_query_moderation",
"Service": config.requestCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
}
signature := getSign(params, config.sk+"&")
@@ -132,31 +201,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
sr := StandardResponse{
Code: 403,
Phase: "Request",
Message: respAdvice.Array()[0].Get("Answer").String(),
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
config.incrementCounter("ai_sec_request_deny", 1)
if config.denyMessage != "" {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1)
} else {
if gjson.GetBytes(body, "stream").Bool() {
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
label := respResult.Array()[0].Get("Label").String()
proxywasm.SetProperty([]string{"risklabel"}, []byte(label))
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.label."+label, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
sr := StandardResponse{
Code: 403,
Phase: "Request",
Message: "risk detected",
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.risk_detected", [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} else {
proxywasm.ResumeHttpRequest()
}
@@ -206,9 +271,16 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
messages := gjson.GetBytes(body, "choices").Array()
if len(messages) > 0 {
content := messages[0].Get("message").Get("content").String()
hdsMap := ctx.GetContext("headers").(map[string][]string)
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
var content string
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
} else {
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
}
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
@@ -220,7 +292,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": "llm_response_moderation",
"Service": config.responseCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
}
signature := getSign(params, config.sk+"&")
@@ -229,7 +301,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
respData := gjson.GetBytes(responseBody, "Data")
@@ -237,31 +309,23 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
sr := StandardResponse{
Code: 403,
Phase: "Response",
Message: respAdvice.Array()[0].Get("Answer").String(),
var jsonData []byte
if config.denyMessage != "" {
jsonData = []byte(config.denyMessage)
} else {
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
} else {
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
}
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
hdsMap := ctx.GetContext("headers").(map[string][]string)
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"403"}
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
sr := StandardResponse{
Code: 403,
Phase: "Response",
Message: "risk detected",
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
hdsMap := ctx.GetContext("headers").(map[string][]string)
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"403"}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
config.incrementCounter("ai_sec_response_deny", 1)
}
}
},
@@ -271,3 +335,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
return types.ActionContinue
}
}
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
strChunks = append(strChunks, jsonObj.String())
}
}
return strings.Join(strChunks, "")
}