[feature] add checking of maliciousUrl & modelHallucination, and adjust consumer specific configs (#3024)

This commit is contained in:
rinfx
2025-10-28 14:12:54 +08:00
committed by GitHub
parent 2076ded06f
commit 2a320f87a6
4 changed files with 365 additions and 105 deletions

View File

@@ -37,8 +37,9 @@ description: 阿里云内容安全检测
| `sensitiveDataLevelBar` | string | optional | S4 | 敏感内容检测拦截风险等级,取值为 `S4`, `S3`, `S2` or `S1` |
| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 |
| `bufferLimit` | int | optional | 1000 | 调用内容安全服务时每段文本的长度限制 |
| `consumerSpecificRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 |
| `consumerSpecificResponseCheckService` | map | optional | - | 为不同消费者指定特定的响应检测服务 |
| `consumerRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 |
| `consumerResponseCheckService` | map | optional | - | 为不同消费者指定特定的响应检测服务 |
| `consumerRiskLevel` | map | optional | - | 为不同消费者指定各维度的拦截风险等级 |
补充说明一下 `denyMessage`,对非法请求的处理逻辑为:
- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容格式为openai格式的流式/非流式响应
@@ -70,6 +71,20 @@ description: 阿里云内容安全检测
![](https://img.alicdn.com/imgextra/i4/O1CN013AbDcn1slCY19inU2_!!6000000005806-0-tps-1754-1320.jpg)
阿里云内容安全配置示例:
```yaml
requestCheckService: llm_query_moderation
responseCheckService: llm_response_moderation
```
阿里云AI安全护栏配置示例
```yaml
requestCheckService: query_security_check
responseCheckService: response_security_check
```
### 检测输入内容是否合规
```yaml

View File

@@ -37,6 +37,9 @@ Plugin Priority: `300`
| `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` |
| `timeout` | int | optional | 2000 | timeout for lvwang service |
| `bufferLimit` | int | optional | 1000 | Limit the length of each text when calling the lvwang service |
| `consumerRequestCheckService` | map | optional | - | Specify specific request detection services for different consumers |
| `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers |
| `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions |
## Examples of configuration

View File

@@ -13,6 +13,7 @@ import (
mrand "math/rand"
"net/http"
"net/url"
"regexp"
"sort"
"strings"
"time"
@@ -51,9 +52,11 @@ const (
S1Sensitive = "S1"
NoSensitive = "S0"
ContentModerationType = "contentModeration"
PromptAttackType = "promptAttack"
SensitiveDataType = "sensitiveData"
ContentModerationType = "contentModeration"
PromptAttackType = "promptAttack"
SensitiveDataType = "sensitiveData"
MaliciousUrlDataType = "maliciousUrl"
ModelHallucinationDataType = "modelHallucination"
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
@@ -108,30 +111,51 @@ type Detail struct {
}
type AISecurityConfig struct {
client wrapper.HttpClient
ak string
sk string
token string
action string
checkRequest bool
requestCheckService string
requestContentJsonPath string
checkResponse bool
responseCheckService string
responseContentJsonPath string
responseStreamContentJsonPath string
denyCode int64
denyMessage string
protocolOriginal bool
riskLevelBar string
contentModerationLevelBar string
promptAttackLevelBar string
sensitiveDataLevelBar string
timeout uint32
bufferLimit int
metrics map[string]proxywasm.MetricCounter
consumerSpecificRequestCheckService map[string]string
consumerSpecificResponseCheckService map[string]string
client wrapper.HttpClient
ak string
sk string
token string
action string
checkRequest bool
requestCheckService string
requestContentJsonPath string
checkResponse bool
responseCheckService string
responseContentJsonPath string
responseStreamContentJsonPath string
denyCode int64
denyMessage string
protocolOriginal bool
riskLevelBar string
contentModerationLevelBar string
promptAttackLevelBar string
sensitiveDataLevelBar string
maliciousUrlLevelBar string
modelHallucinationLevelBar string
timeout uint32
bufferLimit int
metrics map[string]proxywasm.MetricCounter
consumerRequestCheckService []map[string]interface{}
consumerResponseCheckService []map[string]interface{}
consumerRiskLevel []map[string]interface{}
}
type Matcher struct {
Exact string
Prefix string
Re *regexp.Regexp
}
func (m *Matcher) match(consumer string) bool {
if m.Exact != "" {
return consumer == m.Exact
} else if m.Prefix != "" {
return strings.HasPrefix(consumer, m.Prefix)
} else if m.Re != nil {
return m.Re.MatchString(consumer)
} else {
return false
}
}
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
@@ -143,6 +167,126 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)
counter.Increment(inc)
}
func (config *AISecurityConfig) getRequestCheckService(consumer string) string {
result := config.requestCheckService
for _, obj := range config.consumerRequestCheckService {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if requestCheckService, ok := obj["requestCheckService"]; ok {
result, _ = requestCheckService.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getResponseCheckService(consumer string) string {
result := config.responseCheckService
for _, obj := range config.consumerResponseCheckService {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if responseCheckService, ok := obj["responseCheckService"]; ok {
result, _ = responseCheckService.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getRiskLevelBar(consumer string) string {
result := config.riskLevelBar
for _, obj := range config.consumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
result, _ = riskLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getContentModerationLevelBar(consumer string) string {
result := config.contentModerationLevelBar
for _, obj := range config.consumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
result, _ = contentModerationLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getPromptAttackLevelBar(consumer string) string {
result := config.promptAttackLevelBar
for _, obj := range config.consumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
result, _ = promptAttackLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getSensitiveDataLevelBar(consumer string) string {
result := config.sensitiveDataLevelBar
for _, obj := range config.consumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
result, _ = sensitiveDataLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getMaliciousUrlLevelBar(consumer string) string {
result := config.maliciousUrlLevelBar
for _, obj := range config.consumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
result, _ = maliciousUrlLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) getModelHallucinationLevelBar(consumer string) string {
result := config.modelHallucinationLevelBar
for _, obj := range config.consumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
result, _ = modelHallucinationLevelBar.(string)
}
break
}
}
}
return result
}
func levelToInt(riskLevel string) int {
// First check against our defined constants
switch riskLevel {
@@ -195,14 +339,14 @@ func levelToInt(riskLevel string) int {
}
}
func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bool {
func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
if action == "MultiModalGuard" {
// Check top-level risk levels for MultiModalGuard
if levelToInt(data.RiskLevel) >= levelToInt(config.contentModerationLevelBar) {
if levelToInt(data.RiskLevel) >= levelToInt(config.getContentModerationLevelBar(consumer)) {
return false
}
// Also check AttackLevel for prompt attack detection
if levelToInt(data.AttackLevel) >= levelToInt(config.promptAttackLevelBar) {
if levelToInt(data.AttackLevel) >= levelToInt(config.getPromptAttackLevelBar(consumer)) {
return false
}
@@ -210,22 +354,30 @@ func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bo
for _, detail := range data.Detail {
switch detail.Type {
case ContentModerationType:
if levelToInt(detail.Level) >= levelToInt(config.contentModerationLevelBar) {
if levelToInt(detail.Level) >= levelToInt(config.getContentModerationLevelBar(consumer)) {
return false
}
case PromptAttackType:
if levelToInt(detail.Level) >= levelToInt(config.promptAttackLevelBar) {
if levelToInt(detail.Level) >= levelToInt(config.getPromptAttackLevelBar(consumer)) {
return false
}
case SensitiveDataType:
if levelToInt(detail.Level) >= levelToInt(config.sensitiveDataLevelBar) {
if levelToInt(detail.Level) >= levelToInt(config.getSensitiveDataLevelBar(consumer)) {
return false
}
case MaliciousUrlDataType:
if levelToInt(detail.Level) >= levelToInt(config.getMaliciousUrlLevelBar(consumer)) {
return false
}
case ModelHallucinationDataType:
if levelToInt(detail.Level) >= levelToInt(config.getModelHallucinationLevelBar(consumer)) {
return false
}
}
}
return true
} else {
return levelToInt(data.RiskLevel) < levelToInt(config.riskLevelBar)
return levelToInt(data.RiskLevel) < levelToInt(config.getRiskLevelBar(consumer))
}
}
@@ -351,6 +503,22 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error {
} else {
config.sensitiveDataLevelBar = S4Sensitive
}
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
config.modelHallucinationLevelBar = obj.String()
if levelToInt(config.modelHallucinationLevelBar) <= 0 {
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
}
} else {
config.modelHallucinationLevelBar = MaxRisk
}
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
config.maliciousUrlLevelBar = obj.String()
if levelToInt(config.maliciousUrlLevelBar) <= 0 {
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
}
} else {
config.maliciousUrlLevelBar = MaxRisk
}
if obj := json.Get("timeout"); obj.Exists() {
config.timeout = uint32(obj.Int())
} else {
@@ -361,13 +529,71 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error {
} else {
config.bufferLimit = 1000
}
config.consumerSpecificRequestCheckService = make(map[string]string)
for k, v := range json.Get("consumerSpecificRequestCheckService").Map() {
config.consumerSpecificRequestCheckService[k] = v.String()
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
for _, item := range json.Get("consumerRequestCheckService").Array() {
m := make(map[string]interface{})
for k, v := range item.Map() {
m[k] = v.Value()
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
continue
}
switch fmt.Sprint(matchType) {
case "exact":
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
case "prefix":
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
case "regexp":
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
}
config.consumerRequestCheckService = append(config.consumerRequestCheckService, m)
}
}
config.consumerSpecificResponseCheckService = make(map[string]string)
for k, v := range json.Get("consumerSpecificResponseCheckService").Map() {
config.consumerSpecificResponseCheckService[k] = v.String()
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
for _, item := range json.Get("consumerResponseCheckService").Array() {
m := make(map[string]interface{})
for k, v := range item.Map() {
m[k] = v.Value()
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
continue
}
switch fmt.Sprint(matchType) {
case "exact":
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
case "prefix":
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
case "regexp":
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
}
config.consumerResponseCheckService = append(config.consumerResponseCheckService, m)
}
}
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
for _, item := range json.Get("consumerRiskLevel").Array() {
m := make(map[string]interface{})
for k, v := range item.Map() {
m[k] = v.Value()
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
continue
}
switch fmt.Sprint(matchType) {
case "exact":
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
case "prefix":
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
case "regexp":
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
}
config.consumerRiskLevel = append(config.consumerRiskLevel, m)
}
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
@@ -399,6 +625,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) type
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking request body...")
startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
@@ -423,7 +650,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
proxywasm.ResumeHttpRequest()
return
}
if isRiskLevelAcceptable(config.action, response.Data, config) {
if isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
@@ -441,7 +668,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage)
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
@@ -476,11 +703,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificRequestCheckService[consumer]
if !ok {
checkService = config.requestCheckService
}
checkService := config.getRequestCheckService(consumer)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
@@ -491,7 +714,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent),
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent),
}
if config.token != "" {
params["SecurityToken"] = config.token
@@ -540,6 +763,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) typ
}
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, data []byte, endOfStream bool) []byte {
consumer, _ := ctx.GetContext("consumer").(string)
var bufferQueue [][]byte
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
@@ -561,14 +785,14 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
ctx.SetContext("during_call", false)
return
}
if !isRiskLevelAcceptable(config.action, response.Data, config) {
if !isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
denyMessage := DefaultDenyMessage
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = "\n" + response.Data.Advice[0].Answer
} else if config.denyMessage != "" {
denyMessage = config.denyMessage
}
marshalledDenyMessage := marshalStr(denyMessage)
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
@@ -587,7 +811,6 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
return
}
if ctx.BufferQueueSize() >= config.bufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
ctx.SetContext("during_call", true)
var buffer string
for ctx.BufferQueueSize() > 0 {
front := ctx.PopBuffer()
@@ -598,14 +821,16 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
break
}
}
// if streaming body has reasoning_content, buffer maybe empty
log.Debugf("current content piece: %s", buffer)
if len(buffer) == 0 {
return
}
ctx.SetContext("during_call", true)
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
log.Debugf("current content piece: %s", buffer)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
if !ok {
checkService = config.responseCheckService
}
checkService := config.getResponseCheckService(consumer)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
@@ -616,7 +841,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer), AliyunUserAgent),
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), wrapper.MarshalStr(buffer), AliyunUserAgent),
}
if config.token != "" {
params["SecurityToken"] = config.token
@@ -637,7 +862,9 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
}
}
if !ctx.GetContext("risk_detected").(bool) {
ctx.PushBuffer(data)
for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
}
ctx.SetContext("end_of_stream_received", endOfStream)
if !ctx.GetContext("during_call").(bool) {
singleCall()
@@ -649,6 +876,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking response body...")
startTime := time.Now().UnixMilli()
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
@@ -680,7 +908,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
proxywasm.ResumeHttpResponse()
return
}
if isRiskLevelAcceptable(config.action, response.Data, config) {
if isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
@@ -698,7 +926,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage)
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse {
@@ -732,11 +960,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
if !ok {
checkService = config.responseCheckService
}
checkService := config.getResponseCheckService(consumer)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
@@ -747,7 +971,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent),
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent),
}
if config.token != "" {
params["SecurityToken"] = config.token
@@ -769,7 +993,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
}
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
@@ -777,17 +1001,3 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
}
return strings.Join(strChunks, "")
}
func marshalStr(raw string) string {
helper := map[string]string{
"placeholder": raw,
}
marshalledHelper, _ := json.Marshal(helper)
marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw
if len(marshalledRaw) >= 2 {
return marshalledRaw[1 : len(marshalledRaw)-1]
} else {
log.Errorf("failed to marshal json string, raw string is: %s", raw)
return ""
}
}

View File

@@ -96,6 +96,42 @@ var missingAuthConfig = func() json.RawMessage {
return data
}()
// 测试配置:消费者级别特殊配置
var consumerSpecificConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": false,
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"maliciousUrlLevelBar": "high",
"modelHallucinationLevelBar": "high",
"timeout": 1000,
"bufferLimit": 500,
"consumerRequestCheckService": map[string]interface{}{
"name": "aaa",
"matchType": "exact",
"requestCheckService": "llm_query_moderation_1",
},
"consumerResponseCheckService": map[string]interface{}{
"name": "bbb",
"matchType": "prefix",
"responseCheckService": "llm_response_moderation_1",
},
"consumerRiskLevel": map[string]interface{}{
"name": "ccc.*",
"matchType": "regexp",
"maliciousUrlLevelBar": "low",
},
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基础配置解析
@@ -156,6 +192,24 @@ func TestParseConfig(t *testing.T) {
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试消费者级别配置
t.Run("consumer specific config", func(t *testing.T) {
host, status := test.NewTestHost(consumerSpecificConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
@@ -400,25 +454,3 @@ func TestUtilityFunctions(t *testing.T) {
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
})
}
func TestMarshalFunctions(t *testing.T) {
// 测试marshalStr函数
t.Run("marshal string", func(t *testing.T) {
testStr := "Hello, World!"
marshalled := marshalStr(testStr)
require.Equal(t, testStr, marshalled)
})
// 测试extractMessageFromStreamingBody函数
t.Run("extract streaming body", func(t *testing.T) {
// 使用正确的分隔符每个chunk之间用双换行符分隔
streamingData := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"}}]}
{"choices":[{"index":0,"delta":{"role":"assistant","content":" World"}}]}
{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)
extracted := extractMessageFromStreamingBody(streamingData, "choices.0.delta.content")
require.Equal(t, "Hello World", extracted)
})
}