diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index b02723e15..0d25d711a 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -37,8 +37,9 @@ description: 阿里云内容安全检测 | `sensitiveDataLevelBar` | string | optional | S4 | 敏感内容检测拦截风险等级,取值为 `S4`, `S3`, `S2` or `S1` | | `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 | | `bufferLimit` | int | optional | 1000 | 调用内容安全服务时每段文本的长度限制 | -| `consumerSpecificRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 | -| `consumerSpecificResponseCheckService` | map | optional | - | 为不同消费者指定特定的响应检测服务 | +| `consumerRequestCheckService` | map | optional | - | 为不同消费者指定特定的请求检测服务 | +| `consumerResponseCheckService` | map | optional | - | 为不同消费者指定特定的响应检测服务 | +| `consumerRiskLevel` | map | optional | - | 为不同消费者指定各维度的拦截风险等级 | 补充说明一下 `denyMessage`,对非法请求的处理逻辑为: - 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应 @@ -70,6 +71,20 @@ description: 阿里云内容安全检测 ![](https://img.alicdn.com/imgextra/i4/O1CN013AbDcn1slCY19inU2_!!6000000005806-0-tps-1754-1320.jpg) +阿里云内容安全配置示例: + +```yaml +requestCheckService: llm_query_moderation +responseCheckService: llm_response_moderation +``` + +阿里云AI安全护栏配置示例: + +```yaml +requestCheckService: query_security_check +responseCheckService: response_security_check +``` + ### 检测输入内容是否合规 ```yaml diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md index 360893070..d3fe29e45 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -37,6 +37,9 @@ Plugin Priority: `300` | `sensitiveDataLevelBar` | string | optional | S4 | sensitiveData risk level threshold, `S4`, `S3`, `S2` or `S1` | | `timeout` | int | optional | 2000 | timeout for lvwang service | | `bufferLimit` | int | optional | 1000 | Limit the length of each text when calling the lvwang service | +| `consumerRequestCheckService` | map | optional | - | Specify specific request detection services for different consumers | +| `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers | +| `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions | ## Examples of configuration diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 3e0c2fba4..1392db911 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -13,6 +13,7 @@ import ( mrand "math/rand" "net/http" "net/url" + "regexp" "sort" "strings" "time" @@ -51,9 +52,11 @@ const ( S1Sensitive = "S1" NoSensitive = "S0" - ContentModerationType = "contentModeration" - PromptAttackType = "promptAttack" - SensitiveDataType = "sensitiveData" + ContentModerationType = "contentModeration" + PromptAttackType = "promptAttack" + SensitiveDataType = "sensitiveData" + MaliciousUrlDataType = "maliciousUrl" + ModelHallucinationDataType = "modelHallucination" OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` @@ -108,30 +111,51 @@ type Detail struct { } type AISecurityConfig struct { - client wrapper.HttpClient - ak string - sk string - token string - action string - checkRequest bool - requestCheckService string - requestContentJsonPath string - checkResponse bool - responseCheckService string - responseContentJsonPath string - responseStreamContentJsonPath string - denyCode int64 - denyMessage string - protocolOriginal bool - riskLevelBar string - contentModerationLevelBar string - promptAttackLevelBar string - sensitiveDataLevelBar string - timeout uint32 - bufferLimit int - metrics map[string]proxywasm.MetricCounter - consumerSpecificRequestCheckService map[string]string - consumerSpecificResponseCheckService map[string]string + client wrapper.HttpClient + ak string + sk string + token string + action string + checkRequest bool + requestCheckService string + requestContentJsonPath string + checkResponse bool + responseCheckService string + responseContentJsonPath string + responseStreamContentJsonPath string + denyCode int64 + denyMessage string + protocolOriginal bool + riskLevelBar string + contentModerationLevelBar string + promptAttackLevelBar string + sensitiveDataLevelBar string + maliciousUrlLevelBar string + modelHallucinationLevelBar string + timeout uint32 + bufferLimit int + metrics map[string]proxywasm.MetricCounter + consumerRequestCheckService []map[string]interface{} + consumerResponseCheckService []map[string]interface{} + consumerRiskLevel []map[string]interface{} +} + +type Matcher struct { + Exact string + Prefix string + Re *regexp.Regexp +} + +func (m *Matcher) match(consumer string) bool { + if m.Exact != "" { + return consumer == m.Exact + } else if m.Prefix != "" { + return strings.HasPrefix(consumer, m.Prefix) + } else if m.Re != nil { + return m.Re.MatchString(consumer) + } else { + return false + } } func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) { @@ -143,6 +167,126 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) counter.Increment(inc) } +func (config *AISecurityConfig) getRequestCheckService(consumer string) string { + result := config.requestCheckService + for _, obj := range config.consumerRequestCheckService { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if requestCheckService, ok := obj["requestCheckService"]; ok { + result, _ = requestCheckService.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getResponseCheckService(consumer string) string { + result := config.responseCheckService + for _, obj := range config.consumerResponseCheckService { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if responseCheckService, ok := obj["responseCheckService"]; ok { + result, _ = responseCheckService.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getRiskLevelBar(consumer string) string { + result := config.riskLevelBar + for _, obj := range config.consumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if riskLevelBar, ok := obj["riskLevelBar"]; ok { + result, _ = riskLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getContentModerationLevelBar(consumer string) string { + result := config.contentModerationLevelBar + for _, obj := range config.consumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok { + result, _ = contentModerationLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getPromptAttackLevelBar(consumer string) string { + result := config.promptAttackLevelBar + for _, obj := range config.consumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok { + result, _ = promptAttackLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getSensitiveDataLevelBar(consumer string) string { + result := config.sensitiveDataLevelBar + for _, obj := range config.consumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok { + result, _ = sensitiveDataLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getMaliciousUrlLevelBar(consumer string) string { + result := config.maliciousUrlLevelBar + for _, obj := range config.consumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok { + result, _ = maliciousUrlLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) getModelHallucinationLevelBar(consumer string) string { + result := config.modelHallucinationLevelBar + for _, obj := range config.consumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok { + result, _ = modelHallucinationLevelBar.(string) + } + break + } + } + } + return result +} + func levelToInt(riskLevel string) int { // First check against our defined constants switch riskLevel { @@ -195,14 +339,14 @@ func levelToInt(riskLevel string) int { } } -func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bool { +func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool { if action == "MultiModalGuard" { // Check top-level risk levels for MultiModalGuard - if levelToInt(data.RiskLevel) >= levelToInt(config.contentModerationLevelBar) { + if levelToInt(data.RiskLevel) >= levelToInt(config.getContentModerationLevelBar(consumer)) { return false } // Also check AttackLevel for prompt attack detection - if levelToInt(data.AttackLevel) >= levelToInt(config.promptAttackLevelBar) { + if levelToInt(data.AttackLevel) >= levelToInt(config.getPromptAttackLevelBar(consumer)) { return false } @@ -210,22 +354,30 @@ func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig) bo for _, detail := range data.Detail { switch detail.Type { case ContentModerationType: - if levelToInt(detail.Level) >= levelToInt(config.contentModerationLevelBar) { + if levelToInt(detail.Level) >= levelToInt(config.getContentModerationLevelBar(consumer)) { return false } case PromptAttackType: - if levelToInt(detail.Level) >= levelToInt(config.promptAttackLevelBar) { + if levelToInt(detail.Level) >= levelToInt(config.getPromptAttackLevelBar(consumer)) { return false } case SensitiveDataType: - if levelToInt(detail.Level) >= levelToInt(config.sensitiveDataLevelBar) { + if levelToInt(detail.Level) >= levelToInt(config.getSensitiveDataLevelBar(consumer)) { + return false + } + case MaliciousUrlDataType: + if levelToInt(detail.Level) >= levelToInt(config.getMaliciousUrlLevelBar(consumer)) { + return false + } + case ModelHallucinationDataType: + if levelToInt(detail.Level) >= levelToInt(config.getModelHallucinationLevelBar(consumer)) { return false } } } return true } else { - return levelToInt(data.RiskLevel) < levelToInt(config.riskLevelBar) + return levelToInt(data.RiskLevel) < levelToInt(config.getRiskLevelBar(consumer)) } } @@ -351,6 +503,22 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error { } else { config.sensitiveDataLevelBar = S4Sensitive } + if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() { + config.modelHallucinationLevelBar = obj.String() + if levelToInt(config.modelHallucinationLevelBar) <= 0 { + return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]") + } + } else { + config.modelHallucinationLevelBar = MaxRisk + } + if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() { + config.maliciousUrlLevelBar = obj.String() + if levelToInt(config.maliciousUrlLevelBar) <= 0 { + return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]") + } + } else { + config.maliciousUrlLevelBar = MaxRisk + } if obj := json.Get("timeout"); obj.Exists() { config.timeout = uint32(obj.Int()) } else { @@ -361,13 +529,71 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error { } else { config.bufferLimit = 1000 } - config.consumerSpecificRequestCheckService = make(map[string]string) - for k, v := range json.Get("consumerSpecificRequestCheckService").Map() { - config.consumerSpecificRequestCheckService[k] = v.String() + if obj := json.Get("consumerRequestCheckService"); obj.Exists() { + for _, item := range json.Get("consumerRequestCheckService").Array() { + m := make(map[string]interface{}) + for k, v := range item.Map() { + m[k] = v.Value() + } + consumerName, ok1 := m["name"] + matchType, ok2 := m["matchType"] + if !ok1 || !ok2 { + continue + } + switch fmt.Sprint(matchType) { + case "exact": + m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} + case "prefix": + m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} + case "regexp": + m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} + } + config.consumerRequestCheckService = append(config.consumerRequestCheckService, m) + } } - config.consumerSpecificResponseCheckService = make(map[string]string) - for k, v := range json.Get("consumerSpecificResponseCheckService").Map() { - config.consumerSpecificResponseCheckService[k] = v.String() + if obj := json.Get("consumerResponseCheckService"); obj.Exists() { + for _, item := range json.Get("consumerResponseCheckService").Array() { + m := make(map[string]interface{}) + for k, v := range item.Map() { + m[k] = v.Value() + } + consumerName, ok1 := m["name"] + matchType, ok2 := m["matchType"] + if !ok1 || !ok2 { + continue + } + switch fmt.Sprint(matchType) { + case "exact": + m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} + case "prefix": + m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} + case "regexp": + m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} + } + config.consumerResponseCheckService = append(config.consumerResponseCheckService, m) + } + } + if obj := json.Get("consumerRiskLevel"); obj.Exists() { + for _, item := range json.Get("consumerRiskLevel").Array() { + m := make(map[string]interface{}) + for k, v := range item.Map() { + m[k] = v.Value() + } + consumerName, ok1 := m["name"] + matchType, ok2 := m["matchType"] + if !ok1 || !ok2 { + continue + } + switch fmt.Sprint(matchType) { + case "exact": + m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} + case "prefix": + m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} + case "regexp": + m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} + } + config.consumerRiskLevel = append(config.consumerRiskLevel, m) + } } config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: serviceName, @@ -399,6 +625,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) type } func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) log.Debugf("checking request body...") startTime := time.Now().UnixMilli() content := gjson.GetBytes(body, config.requestContentJsonPath).String() @@ -423,7 +650,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] proxywasm.ResumeHttpRequest() return } - if isRiskLevelAcceptable(config.action, response.Data, config) { + if isRiskLevelAcceptable(config.action, response.Data, config, consumer) { if contentIndex >= len(content) { endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) @@ -441,7 +668,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { denyMessage = response.Data.Advice[0].Answer } - marshalledDenyMessage := marshalStr(denyMessage) + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) if config.protocolOriginal { proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) } else if gjson.GetBytes(body, "stream").Bool() { @@ -476,11 +703,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] contentPiece := content[contentIndex:nextContentIndex] contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) - consumer, _ := ctx.GetContext("consumer").(string) - checkService, ok := config.consumerSpecificRequestCheckService[consumer] - if !ok { - checkService = config.requestCheckService - } + checkService := config.getRequestCheckService(consumer) params := map[string]string{ "Format": "JSON", "Version": "2022-03-02", @@ -491,7 +714,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent), + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent), } if config.token != "" { params["SecurityToken"] = config.token @@ -540,6 +763,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) typ } func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, data []byte, endOfStream bool) []byte { + consumer, _ := ctx.GetContext("consumer").(string) var bufferQueue [][]byte var singleCall func() callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { @@ -561,14 +785,14 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi ctx.SetContext("during_call", false) return } - if !isRiskLevelAcceptable(config.action, response.Data, config) { + if !isRiskLevelAcceptable(config.action, response.Data, config, consumer) { denyMessage := DefaultDenyMessage if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { denyMessage = "\n" + response.Data.Advice[0].Answer } else if config.denyMessage != "" { denyMessage = config.denyMessage } - marshalledDenyMessage := marshalStr(denyMessage) + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) randomID := generateRandomID() jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) proxywasm.InjectEncodedDataToFilterChain(jsonData, true) @@ -587,7 +811,6 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi return } if ctx.BufferQueueSize() >= config.bufferLimit || ctx.GetContext("end_of_stream_received").(bool) { - ctx.SetContext("during_call", true) var buffer string for ctx.BufferQueueSize() > 0 { front := ctx.PopBuffer() @@ -598,14 +821,16 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi break } } + // if streaming body has reasoning_content, buffer maybe empty + log.Debugf("current content piece: %s", buffer) + if len(buffer) == 0 { + return + } + ctx.SetContext("during_call", true) timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) log.Debugf("current content piece: %s", buffer) - consumer, _ := ctx.GetContext("consumer").(string) - checkService, ok := config.consumerSpecificResponseCheckService[consumer] - if !ok { - checkService = config.responseCheckService - } + checkService := config.getResponseCheckService(consumer) params := map[string]string{ "Format": "JSON", "Version": "2022-03-02", @@ -616,7 +841,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer), AliyunUserAgent), + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), wrapper.MarshalStr(buffer), AliyunUserAgent), } if config.token != "" { params["SecurityToken"] = config.token @@ -637,7 +862,9 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi } } if !ctx.GetContext("risk_detected").(bool) { - ctx.PushBuffer(data) + for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) { + ctx.PushBuffer([]byte(string(chunk) + "\n\n")) + } ctx.SetContext("end_of_stream_received", endOfStream) if !ctx.GetContext("during_call").(bool) { singleCall() @@ -649,6 +876,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi } func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) log.Debugf("checking response body...") startTime := time.Now().UnixMilli() contentType, _ := proxywasm.GetHttpResponseHeader("content-type") @@ -680,7 +908,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ proxywasm.ResumeHttpResponse() return } - if isRiskLevelAcceptable(config.action, response.Data, config) { + if isRiskLevelAcceptable(config.action, response.Data, config, consumer) { if contentIndex >= len(content) { endTime := time.Now().UnixMilli() ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) @@ -698,7 +926,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { denyMessage = response.Data.Advice[0].Answer } - marshalledDenyMessage := marshalStr(denyMessage) + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) if config.protocolOriginal { proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) } else if isStreamingResponse { @@ -732,11 +960,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ contentPiece := content[contentIndex:nextContentIndex] contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) - consumer, _ := ctx.GetContext("consumer").(string) - checkService, ok := config.consumerSpecificResponseCheckService[consumer] - if !ok { - checkService = config.responseCheckService - } + checkService := config.getResponseCheckService(consumer) params := map[string]string{ "Format": "JSON", "Version": "2022-03-02", @@ -747,7 +971,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, marshalStr(contentPiece), AliyunUserAgent), + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent), } if config.token != "" { params["SecurityToken"] = config.token @@ -769,7 +993,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ } func extractMessageFromStreamingBody(data []byte, jsonPath string) string { - chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) strChunks := []string{} for _, chunk := range chunks { // Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}] @@ -777,17 +1001,3 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string { } return strings.Join(strChunks, "") } - -func marshalStr(raw string) string { - helper := map[string]string{ - "placeholder": raw, - } - marshalledHelper, _ := json.Marshal(helper) - marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw - if len(marshalledRaw) >= 2 { - return marshalledRaw[1 : len(marshalledRaw)-1] - } else { - log.Errorf("failed to marshal json string, raw string is: %s", raw) - return "" - } -} diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 5ce57d9a5..18cdece64 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -96,6 +96,42 @@ var missingAuthConfig = func() json.RawMessage { return data }() +// 测试配置:消费者级别特殊配置 +var consumerSpecificConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": false, + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "maliciousUrlLevelBar": "high", + "modelHallucinationLevelBar": "high", + "timeout": 1000, + "bufferLimit": 500, + "consumerRequestCheckService": map[string]interface{}{ + "name": "aaa", + "matchType": "exact", + "requestCheckService": "llm_query_moderation_1", + }, + "consumerResponseCheckService": map[string]interface{}{ + "name": "bbb", + "matchType": "prefix", + "responseCheckService": "llm_response_moderation_1", + }, + "consumerRiskLevel": map[string]interface{}{ + "name": "ccc.*", + "matchType": "regexp", + "maliciousUrlLevelBar": "low", + }, + }) + return data +}() + func TestParseConfig(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试基础配置解析 @@ -156,6 +192,24 @@ func TestParseConfig(t *testing.T) { defer host.Reset() require.Equal(t, types.OnPluginStartStatusFailed, status) }) + + // 测试消费者级别配置 + t.Run("consumer specific config", func(t *testing.T) { + host, status := test.NewTestHost(consumerSpecificConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + + securityConfig := config.(*AISecurityConfig) + require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa")) + require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa")) + require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb")) + require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test")) + require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc")) + require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test")) + }) }) } @@ -400,25 +454,3 @@ func TestUtilityFunctions(t *testing.T) { require.Len(t, id, 38) // "chatcmpl-" + 29 random chars }) } - -func TestMarshalFunctions(t *testing.T) { - // 测试marshalStr函数 - t.Run("marshal string", func(t *testing.T) { - testStr := "Hello, World!" - marshalled := marshalStr(testStr) - require.Equal(t, testStr, marshalled) - }) - - // 测试extractMessageFromStreamingBody函数 - t.Run("extract streaming body", func(t *testing.T) { - // 使用正确的分隔符,每个chunk之间用双换行符分隔 - streamingData := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"}}]} - -{"choices":[{"index":0,"delta":{"role":"assistant","content":" World"}}]} - -{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`) - - extracted := extractMessageFromStreamingBody(streamingData, "choices.0.delta.content") - require.Equal(t, "Hello World", extracted) - }) -}