mirror of
https://github.com/alibaba/higress.git
synced 2026-03-06 17:40:51 +08:00
[feature] add checking of maliciousUrl & modelHallucination, and adjust consumer specific configs (#3024)
This commit is contained in:
@@ -37,8 +37,9 @@ description: 阿里云内容安全检测
|
||||
| `sensitiveDataLevelBar` | string | optional | S4 | 敏感内容检测拦截风险等级,取值为 `S4`, `S3`, `S2` or `S1` |
|
||||
| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 |
|
||||
| `bufferLimit` | int | optional | 1000 | 调用内容安全服务时每段文本的长度限制 |
|
||||
| `consumerSpecificRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 |
|
||||
| `consumerSpecificResponseCheckService` | map | optional | - | 为不同消费者指定特定的响应检测服务 |
|
||||
| `consumerRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 |
|
||||
| `consumerResponseCheckService` | map | optional | - | 为不同消费者指定特定的响应检测服务 |
|
||||
| `consumerRiskLevel` | map | optional | - | 为不同消费者指定各维度的拦截风险等级 |
|
||||
|
||||
补充说明一下 `denyMessage`,对非法请求的处理逻辑为:
|
||||
- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应
|
||||
@@ -70,6 +71,20 @@ description: 阿里云内容安全检测
|
||||
|
||||

|
||||
|
||||
阿里云内容安全配置示例:
|
||||
|
||||
```yaml
|
||||
requestCheckService: llm_query_moderation
|
||||
responseCheckService: llm_response_moderation
|
||||
```
|
||||
|
||||
阿里云AI安全护栏配置示例:
|
||||
|
||||
```yaml
|
||||
requestCheckService: query_security_check
|
||||
responseCheckService: response_security_check
|
||||
```
|
||||
|
||||
### 检测输入内容是否合规
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -37,6 +37,9 @@ Plugin Priority: `300`
|
||||
| `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` |
|
||||
| `timeout` | int | optional | 2000 | timeout for lvwang service |
|
||||
| `bufferLimit` | int | optional | 1000 | Limit the length of each text when calling the lvwang service |
|
||||
| `consumerRequestCheckService` | map | optional | - | Specify specific request detection services for different consumers |
|
||||
| `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers |
|
||||
| `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions |
|
||||
|
||||
|
||||
## Examples of configuration
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
mrand "math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -51,9 +52,11 @@ const (
|
||||
S1Sensitive = "S1"
|
||||
NoSensitive = "S0"
|
||||
|
||||
ContentModerationType = "contentModeration"
|
||||
PromptAttackType = "promptAttack"
|
||||
SensitiveDataType = "sensitiveData"
|
||||
ContentModerationType = "contentModeration"
|
||||
PromptAttackType = "promptAttack"
|
||||
SensitiveDataType = "sensitiveData"
|
||||
MaliciousUrlDataType = "maliciousUrl"
|
||||
ModelHallucinationDataType = "modelHallucination"
|
||||
|
||||
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||
@@ -108,30 +111,51 @@ type Detail struct {
|
||||
}
|
||||
|
||||
type AISecurityConfig struct {
|
||||
client wrapper.HttpClient
|
||||
ak string
|
||||
sk string
|
||||
token string
|
||||
action string
|
||||
checkRequest bool
|
||||
requestCheckService string
|
||||
requestContentJsonPath string
|
||||
checkResponse bool
|
||||
responseCheckService string
|
||||
responseContentJsonPath string
|
||||
responseStreamContentJsonPath string
|
||||
denyCode int64
|
||||
denyMessage string
|
||||
protocolOriginal bool
|
||||
riskLevelBar string
|
||||
contentModerationLevelBar string
|
||||
promptAttackLevelBar string
|
||||
sensitiveDataLevelBar string
|
||||
timeout uint32
|
||||
bufferLimit int
|
||||
metrics map[string]proxywasm.MetricCounter
|
||||
consumerSpecificRequestCheckService map[string]string
|
||||
consumerSpecificResponseCheckService map[string]string
|
||||
client wrapper.HttpClient
|
||||
ak string
|
||||
sk string
|
||||
token string
|
||||
action string
|
||||
checkRequest bool
|
||||
requestCheckService string
|
||||
requestContentJsonPath string
|
||||
checkResponse bool
|
||||
responseCheckService string
|
||||
responseContentJsonPath string
|
||||
responseStreamContentJsonPath string
|
||||
denyCode int64
|
||||
denyMessage string
|
||||
protocolOriginal bool
|
||||
riskLevelBar string
|
||||
contentModerationLevelBar string
|
||||
promptAttackLevelBar string
|
||||
sensitiveDataLevelBar string
|
||||
maliciousUrlLevelBar string
|
||||
modelHallucinationLevelBar string
|
||||
timeout uint32
|
||||
bufferLimit int
|
||||
metrics map[string]proxywasm.MetricCounter
|
||||
consumerRequestCheckService []map[string]interface{}
|
||||
consumerResponseCheckService []map[string]interface{}
|
||||
consumerRiskLevel []map[string]interface{}
|
||||
}
|
||||
|
||||
type Matcher struct {
|
||||
Exact string
|
||||
Prefix string
|
||||
Re *regexp.Regexp
|
||||
}
|
||||
|
||||
func (m *Matcher) match(consumer string) bool {
|
||||
if m.Exact != "" {
|
||||
return consumer == m.Exact
|
||||
} else if m.Prefix != "" {
|
||||
return strings.HasPrefix(consumer, m.Prefix)
|
||||
} else if m.Re != nil {
|
||||
return m.Re.MatchString(consumer)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
|
||||
@@ -143,6 +167,126 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getRequestCheckService(consumer string) string {
|
||||
result := config.requestCheckService
|
||||
for _, obj := range config.consumerRequestCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if requestCheckService, ok := obj["requestCheckService"]; ok {
|
||||
result, _ = requestCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getResponseCheckService(consumer string) string {
|
||||
result := config.responseCheckService
|
||||
for _, obj := range config.consumerResponseCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if responseCheckService, ok := obj["responseCheckService"]; ok {
|
||||
result, _ = responseCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getRiskLevelBar(consumer string) string {
|
||||
result := config.riskLevelBar
|
||||
for _, obj := range config.consumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
|
||||
result, _ = riskLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getContentModerationLevelBar(consumer string) string {
|
||||
result := config.contentModerationLevelBar
|
||||
for _, obj := range config.consumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
|
||||
result, _ = contentModerationLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getPromptAttackLevelBar(consumer string) string {
|
||||
result := config.promptAttackLevelBar
|
||||
for _, obj := range config.consumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
|
||||
result, _ = promptAttackLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getSensitiveDataLevelBar(consumer string) string {
|
||||
result := config.sensitiveDataLevelBar
|
||||
for _, obj := range config.consumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
|
||||
result, _ = sensitiveDataLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getMaliciousUrlLevelBar(consumer string) string {
|
||||
result := config.maliciousUrlLevelBar
|
||||
for _, obj := range config.consumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
|
||||
result, _ = maliciousUrlLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) getModelHallucinationLevelBar(consumer string) string {
|
||||
result := config.modelHallucinationLevelBar
|
||||
for _, obj := range config.consumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
|
||||
result, _ = modelHallucinationLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func levelToInt(riskLevel string) int {
|
||||
// First check against our defined constants
|
||||
switch riskLevel {
|
||||
@@ -195,14 +339,14 @@ func levelToInt(riskLevel string) int {
|
||||
}
|
||||
}
|
||||
|
||||
func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bool {
|
||||
func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
|
||||
if action == "MultiModalGuard" {
|
||||
// Check top-level risk levels for MultiModalGuard
|
||||
if levelToInt(data.RiskLevel) >= levelToInt(config.contentModerationLevelBar) {
|
||||
if levelToInt(data.RiskLevel) >= levelToInt(config.getContentModerationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
// Also check AttackLevel for prompt attack detection
|
||||
if levelToInt(data.AttackLevel) >= levelToInt(config.promptAttackLevelBar) {
|
||||
if levelToInt(data.AttackLevel) >= levelToInt(config.getPromptAttackLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -210,22 +354,30 @@ func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bo
|
||||
for _, detail := range data.Detail {
|
||||
switch detail.Type {
|
||||
case ContentModerationType:
|
||||
if levelToInt(detail.Level) >= levelToInt(config.contentModerationLevelBar) {
|
||||
if levelToInt(detail.Level) >= levelToInt(config.getContentModerationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case PromptAttackType:
|
||||
if levelToInt(detail.Level) >= levelToInt(config.promptAttackLevelBar) {
|
||||
if levelToInt(detail.Level) >= levelToInt(config.getPromptAttackLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case SensitiveDataType:
|
||||
if levelToInt(detail.Level) >= levelToInt(config.sensitiveDataLevelBar) {
|
||||
if levelToInt(detail.Level) >= levelToInt(config.getSensitiveDataLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case MaliciousUrlDataType:
|
||||
if levelToInt(detail.Level) >= levelToInt(config.getMaliciousUrlLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case ModelHallucinationDataType:
|
||||
if levelToInt(detail.Level) >= levelToInt(config.getModelHallucinationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return levelToInt(data.RiskLevel) < levelToInt(config.riskLevelBar)
|
||||
return levelToInt(data.RiskLevel) < levelToInt(config.getRiskLevelBar(consumer))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,6 +503,22 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error {
|
||||
} else {
|
||||
config.sensitiveDataLevelBar = S4Sensitive
|
||||
}
|
||||
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
|
||||
config.modelHallucinationLevelBar = obj.String()
|
||||
if levelToInt(config.modelHallucinationLevelBar) <= 0 {
|
||||
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
} else {
|
||||
config.modelHallucinationLevelBar = MaxRisk
|
||||
}
|
||||
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
|
||||
config.maliciousUrlLevelBar = obj.String()
|
||||
if levelToInt(config.maliciousUrlLevelBar) <= 0 {
|
||||
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
} else {
|
||||
config.maliciousUrlLevelBar = MaxRisk
|
||||
}
|
||||
if obj := json.Get("timeout"); obj.Exists() {
|
||||
config.timeout = uint32(obj.Int())
|
||||
} else {
|
||||
@@ -361,13 +529,71 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error {
|
||||
} else {
|
||||
config.bufferLimit = 1000
|
||||
}
|
||||
config.consumerSpecificRequestCheckService = make(map[string]string)
|
||||
for k, v := range json.Get("consumerSpecificRequestCheckService").Map() {
|
||||
config.consumerSpecificRequestCheckService[k] = v.String()
|
||||
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerRequestCheckService").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.consumerRequestCheckService = append(config.consumerRequestCheckService, m)
|
||||
}
|
||||
}
|
||||
config.consumerSpecificResponseCheckService = make(map[string]string)
|
||||
for k, v := range json.Get("consumerSpecificResponseCheckService").Map() {
|
||||
config.consumerSpecificResponseCheckService[k] = v.String()
|
||||
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerResponseCheckService").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.consumerResponseCheckService = append(config.consumerResponseCheckService, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerRiskLevel").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.consumerRiskLevel = append(config.consumerRiskLevel, m)
|
||||
}
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
@@ -399,6 +625,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) type
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking request body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||
@@ -423,7 +650,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if isRiskLevelAcceptable(config.action, response.Data, config) {
|
||||
if isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
@@ -441,7 +668,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := marshalStr(denyMessage)
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.protocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
@@ -476,11 +703,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService, ok := config.consumerSpecificRequestCheckService[consumer]
|
||||
if !ok {
|
||||
checkService = config.requestCheckService
|
||||
}
|
||||
checkService := config.getRequestCheckService(consumer)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
@@ -491,7 +714,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": checkService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent),
|
||||
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent),
|
||||
}
|
||||
if config.token != "" {
|
||||
params["SecurityToken"] = config.token
|
||||
@@ -540,6 +763,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) typ
|
||||
}
|
||||
|
||||
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
var bufferQueue [][]byte
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
@@ -561,14 +785,14 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
|
||||
ctx.SetContext("during_call", false)
|
||||
return
|
||||
}
|
||||
if !isRiskLevelAcceptable(config.action, response.Data, config) {
|
||||
if !isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
|
||||
denyMessage := DefaultDenyMessage
|
||||
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = "\n" + response.Data.Advice[0].Answer
|
||||
} else if config.denyMessage != "" {
|
||||
denyMessage = config.denyMessage
|
||||
}
|
||||
marshalledDenyMessage := marshalStr(denyMessage)
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
randomID := generateRandomID()
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
|
||||
@@ -587,7 +811,6 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
|
||||
return
|
||||
}
|
||||
if ctx.BufferQueueSize() >= config.bufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
|
||||
ctx.SetContext("during_call", true)
|
||||
var buffer string
|
||||
for ctx.BufferQueueSize() > 0 {
|
||||
front := ctx.PopBuffer()
|
||||
@@ -598,14 +821,16 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
|
||||
break
|
||||
}
|
||||
}
|
||||
// if streaming body has reasoning_content, buffer maybe empty
|
||||
log.Debugf("current content piece: %s", buffer)
|
||||
if len(buffer) == 0 {
|
||||
return
|
||||
}
|
||||
ctx.SetContext("during_call", true)
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
log.Debugf("current content piece: %s", buffer)
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
|
||||
if !ok {
|
||||
checkService = config.responseCheckService
|
||||
}
|
||||
checkService := config.getResponseCheckService(consumer)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
@@ -616,7 +841,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": checkService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer), AliyunUserAgent),
|
||||
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), wrapper.MarshalStr(buffer), AliyunUserAgent),
|
||||
}
|
||||
if config.token != "" {
|
||||
params["SecurityToken"] = config.token
|
||||
@@ -637,7 +862,9 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
|
||||
}
|
||||
}
|
||||
if !ctx.GetContext("risk_detected").(bool) {
|
||||
ctx.PushBuffer(data)
|
||||
for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
|
||||
ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
|
||||
}
|
||||
ctx.SetContext("end_of_stream_received", endOfStream)
|
||||
if !ctx.GetContext("during_call").(bool) {
|
||||
singleCall()
|
||||
@@ -649,6 +876,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
@@ -680,7 +908,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
if isRiskLevelAcceptable(config.action, response.Data, config) {
|
||||
if isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
@@ -698,7 +926,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := marshalStr(denyMessage)
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.protocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if isStreamingResponse {
|
||||
@@ -732,11 +960,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
|
||||
if !ok {
|
||||
checkService = config.responseCheckService
|
||||
}
|
||||
checkService := config.getResponseCheckService(consumer)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
@@ -747,7 +971,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": checkService,
|
||||
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent),
|
||||
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent),
|
||||
}
|
||||
if config.token != "" {
|
||||
params["SecurityToken"] = config.token
|
||||
@@ -769,7 +993,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
}
|
||||
|
||||
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
|
||||
strChunks := []string{}
|
||||
for _, chunk := range chunks {
|
||||
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||
@@ -777,17 +1001,3 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
}
|
||||
return strings.Join(strChunks, "")
|
||||
}
|
||||
|
||||
func marshalStr(raw string) string {
|
||||
helper := map[string]string{
|
||||
"placeholder": raw,
|
||||
}
|
||||
marshalledHelper, _ := json.Marshal(helper)
|
||||
marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw
|
||||
if len(marshalledRaw) >= 2 {
|
||||
return marshalledRaw[1 : len(marshalledRaw)-1]
|
||||
} else {
|
||||
log.Errorf("failed to marshal json string, raw string is: %s", raw)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +96,42 @@ var missingAuthConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:消费者级别特殊配置
|
||||
var consumerSpecificConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "security-service",
|
||||
"servicePort": 8080,
|
||||
"serviceHost": "security.example.com",
|
||||
"accessKey": "test-ak",
|
||||
"secretKey": "test-sk",
|
||||
"checkRequest": true,
|
||||
"checkResponse": false,
|
||||
"contentModerationLevelBar": "high",
|
||||
"promptAttackLevelBar": "high",
|
||||
"sensitiveDataLevelBar": "S3",
|
||||
"maliciousUrlLevelBar": "high",
|
||||
"modelHallucinationLevelBar": "high",
|
||||
"timeout": 1000,
|
||||
"bufferLimit": 500,
|
||||
"consumerRequestCheckService": map[string]interface{}{
|
||||
"name": "aaa",
|
||||
"matchType": "exact",
|
||||
"requestCheckService": "llm_query_moderation_1",
|
||||
},
|
||||
"consumerResponseCheckService": map[string]interface{}{
|
||||
"name": "bbb",
|
||||
"matchType": "prefix",
|
||||
"responseCheckService": "llm_response_moderation_1",
|
||||
},
|
||||
"consumerRiskLevel": map[string]interface{}{
|
||||
"name": "ccc.*",
|
||||
"matchType": "regexp",
|
||||
"maliciousUrlLevelBar": "low",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基础配置解析
|
||||
@@ -156,6 +192,24 @@ func TestParseConfig(t *testing.T) {
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// 测试消费者级别配置
|
||||
t.Run("consumer specific config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(consumerSpecificConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
|
||||
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
|
||||
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
|
||||
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
|
||||
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
|
||||
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -400,25 +454,3 @@ func TestUtilityFunctions(t *testing.T) {
|
||||
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
|
||||
})
|
||||
}
|
||||
|
||||
func TestMarshalFunctions(t *testing.T) {
|
||||
// 测试marshalStr函数
|
||||
t.Run("marshal string", func(t *testing.T) {
|
||||
testStr := "Hello, World!"
|
||||
marshalled := marshalStr(testStr)
|
||||
require.Equal(t, testStr, marshalled)
|
||||
})
|
||||
|
||||
// 测试extractMessageFromStreamingBody函数
|
||||
t.Run("extract streaming body", func(t *testing.T) {
|
||||
// 使用正确的分隔符,每个chunk之间用双换行符分隔
|
||||
streamingData := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"}}]}
|
||||
|
||||
{"choices":[{"index":0,"delta":{"role":"assistant","content":" World"}}]}
|
||||
|
||||
{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)
|
||||
|
||||
extracted := extractMessageFromStreamingBody(streamingData, "choices.0.delta.content")
|
||||
require.Equal(t, "Hello World", extracted)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user