// File: higress/plugins/wasm-go/extensions/ai-security-guard/config/config.go
// (935 lines, 31 KiB, Go)

package config
import (
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// Shared constants for the AI security guard plugin: level names, detail
// types, response templates, defaults, action names and service names.
const (
// Risk level names used by the non-sensitive-data dimensions, highest first.
// LevelToInt maps them to 4 (max) .. 0 (none).
MaxRisk = "max"
HighRisk = "high"
MediumRisk = "medium"
LowRisk = "low"
NoRisk = "none"
// Sensitive-data level names, highest first. LevelToInt maps them to
// 4 (s4) .. 0 (s0), i.e. the same scale as the risk levels above.
S4Sensitive = "s4"
S3Sensitive = "s3"
S2Sensitive = "s2"
S1Sensitive = "s1"
NoSensitive = "s0"
// Detail.Type values identifying which risk dimension a finding belongs to.
// MaliciousFileType and WaterMarkType have no level bar or dimension action
// in this file.
ContentModerationType = "contentModeration"
PromptAttackType = "promptAttack"
SensitiveDataType = "sensitiveData"
MaliciousUrlDataType = "maliciousUrl"
ModelHallucinationDataType = "modelHallucination"
CustomLabelType = "customLabel"
MaliciousFileType = "maliciousFile"
WaterMarkType = "waterMark"
// Default configurations
// OpenAIResponseFormat is a non-streaming "chat.completion" body; its two %s
// verbs are the response id and the message content.
// NOTE(review): values are substituted verbatim into JSON, so callers must
// JSON-escape the content themselves.
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
// OpenAIStreamResponseChunk is one SSE "chat.completion.chunk" carrying content
// (%s verbs: id, content).
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
// OpenAIStreamResponseEnd is the final SSE chunk with finish_reason "stop"
// (%s verb: id).
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
// OpenAIStreamResponseFormat is content chunk + end chunk + the [DONE] marker,
// separated by blank lines. It takes three %s arguments: id, content, id.
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
// DefaultDenyCode is the initial DenyCode (see SetDefaultValues).
// NOTE(review): presumably used as the HTTP status of deny responses — confirm at call site.
DefaultDenyCode = 200
// DefaultDenyMessage means "Sorry, I can't answer your question". It is not
// referenced in this file. Parse sets DenyMessage to "" when denyMessage is
// absent, so any fallback to this value must happen at the call site.
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
// DefaultTimeout is the initial Timeout (unit presumably milliseconds — confirm).
DefaultTimeout = 2000
// AliyunUserAgent is not referenced in this file; presumably the User-Agent
// sent to the Aliyun security service — confirm at call site.
AliyunUserAgent = "CIPFrom/AIGateway"
// LengthLimit is not referenced in this file; presumably the maximum content
// length sent per check — confirm at call site.
LengthLimit = 1800
// Legacy service-name defaults. They equal the TextModerationPlus defaults
// below, which SetDefaultValues uses instead.
DefaultRequestCheckService = "llm_query_moderation"
DefaultResponseCheckService = "llm_response_moderation"
// Default gjson paths for extracting content: the last message of the request,
// the first choice's message, and the first choice's streaming delta.
DefaultRequestJsonPath = "messages.@reverse.0.content"
DefaultResponseJsonPath = "choices.0.message.content"
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
// Actions: accepted values of the "action" config field (default TextModerationPlus).
MultiModalGuard = "MultiModalGuard"
MultiModalGuardForBase64 = "MultiModalGuardForBase64"
TextModerationPlus = "TextModerationPlus"
// Services: default check-service names applied by SetDefaultValues per action.
DefaultMultiModalGuardTextInputCheckService = "query_security_check"
DefaultMultiModalGuardTextOutputCheckService = "response_security_check"
DefaultMultiModalGuardImageInputCheckService = "img_query_security_check"
DefaultTextModerationPlusTextInputCheckService = "llm_query_moderation"
DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
)
// api types: values of the "apiType" config field (AISecurityConfig.ApiType).
// SetDefaultValues selects ApiTextGeneration.
const (
ApiTextGeneration = "text_generation"
ApiImageGeneration = "image_generation"
ApiMCP = "mcp"
)
// provider types: values of the "providerType" config field
// (AISecurityConfig.ProviderType). SetDefaultValues selects ProviderOpenAI.
const (
ProviderOpenAI = "openai"
ProviderQwen = "qwen"
ProviderComfyUI = "comfyui"
)
// Response is the decoded reply of the security check service.
// NOTE(review): the PascalCase JSON keys presumably mirror the Aliyun
// content-security API envelope — confirm against the service documentation.
type Response struct {
Code int `json:"Code"`
Message string `json:"Message"`
RequestId string `json:"RequestId"`
Data Data `json:"Data"`
}
// Data holds the verdict. RiskLevel and AttackLevel are the top-level levels
// checked by the compatibility gate in evaluateRiskMultiModal. RiskLevel is also
// compared with RiskLevelBar for non-MultiModalGuard actions. Suggestion
// "block" forces a block. Detail carries one entry per risk dimension.
type Data struct {
RiskLevel string `json:"RiskLevel,omitempty"`
AttackLevel string `json:"AttackLevel,omitempty"`
Suggestion string `json:"Suggestion,omitempty"`
Result []Result `json:"Result,omitempty"`
Advice []Advice `json:"Advice,omitempty"`
Detail []Detail `json:"Detail,omitempty"`
}
// Ext carries extra information for a Result. Desensitization is the masked
// text returned by ExtractDesensitization.
type Ext struct {
Desensitization string `json:"Desensitization,omitempty"`
SensitiveData []string `json:"SensitiveData,omitempty"`
}
// Result is a single finding reported by the service.
type Result struct {
RiskWords string `json:"RiskWords,omitempty"`
Description string `json:"Description,omitempty"`
Confidence float64 `json:"Confidence,omitempty"`
Label string `json:"Label,omitempty"`
// NOTE: encoding/json ignores omitempty on struct-typed fields, so Ext is
// always emitted when a Result is marshalled.
Ext Ext `json:"Ext,omitempty"`
}
// Advice is a suggested answer or library hit returned by the service.
// NOTE(review): not consumed in this file — confirm usage elsewhere.
type Advice struct {
Answer string `json:"Answer,omitempty"`
HitLabel string `json:"HitLabel,omitempty"`
HitLibName string `json:"HitLibName,omitempty"`
}
// Detail is the verdict for one risk dimension. Type is one of the *Type
// constants, Level is compared against that dimension's level bar, and
// Suggestion is compared with "block" and "mask" by the evaluation code.
type Detail struct {
Suggestion string `json:"Suggestion,omitempty"`
Type string `json:"Type,omitempty"`
Level string `json:"Level,omitempty"`
Result []Result `json:"Result,omitempty"`
}
// Matcher selects consumers by name. Normally exactly one strategy is set.
// If several are set, Exact wins over Prefix, and Prefix wins over Re.
type Matcher struct {
	Exact  string
	Prefix string
	Re     *regexp.Regexp
}

// match reports whether consumer satisfies the matcher's strategy. A Matcher
// with no strategy configured (empty Exact and Prefix, nil Re) matches nothing.
func (m *Matcher) match(consumer string) bool {
	switch {
	case m.Exact != "":
		return consumer == m.Exact
	case m.Prefix != "":
		return strings.HasPrefix(consumer, m.Prefix)
	case m.Re != nil:
		return m.Re.MatchString(consumer)
	}
	return false
}
// AISecurityConfig is the parsed plugin configuration. It is populated by
// Parse, which applies SetDefaultValues first and then the explicit settings.
type AISecurityConfig struct {
// Client talks to the security service cluster built from serviceName/servicePort/serviceHost.
Client wrapper.HttpClient
// Host is the configured serviceHost.
Host string
// AK/SK come from accessKey/secretKey (both required); Token from securityToken (optional).
AK string
SK string
Token string
// Action is the service action (TextModerationPlus by default, or MultiModalGuard / MultiModalGuardForBase64).
Action string
CheckRequest bool
CheckRequestImage bool
// Check-service names; defaults depend on Action (see SetDefaultValues) and may be overridden per consumer.
RequestCheckService string
RequestImageCheckService string
// gjson path used to extract request content (default DefaultRequestJsonPath).
RequestContentJsonPath string
CheckResponse bool
ResponseCheckService string
ResponseImageCheckService string
// gjson paths used to extract non-streaming / streaming response content.
ResponseContentJsonPath string
ResponseStreamContentJsonPath string
// DenyCode defaults to DefaultDenyCode. DenyMessage is "" unless denyMessage is configured.
DenyCode int64
DenyMessage string
// ProtocolOriginal is true when the "protocol" field equals "original".
ProtocolOriginal bool
// RiskLevelBar (default "high") is the only threshold used for non-MultiModalGuard actions.
RiskLevelBar string
// Per-dimension thresholds for MultiModalGuard. Each defaults to "max"
// ("s4" for sensitive data). A level at or above the bar exceeds it.
ContentModerationLevelBar string
PromptAttackLevelBar string
SensitiveDataLevelBar string
MaliciousUrlLevelBar string
ModelHallucinationLevelBar string
CustomLabelLevelBar string
// Timeout defaults to DefaultTimeout (unit presumably milliseconds — confirm at call site).
Timeout uint32
// BufferLimit defaults to 1000. NOTE(review): meaning/unit not visible here — confirm at call site.
BufferLimit int
// Metrics caches counters by name; see IncrementCounter.
Metrics map[string]proxywasm.MetricCounter
// Per-consumer override rules. Each map holds the raw configured fields plus
// a "matcher" entry (a Matcher value) that Parse adds for known matchTypes.
ConsumerRequestCheckService []map[string]interface{}
ConsumerResponseCheckService []map[string]interface{}
ConsumerRiskLevel []map[string]interface{}
// text_generation, image_generation, etc.
ApiType string
// openai, qwen, comfyui, etc.
ProviderType string
// "block" or "mask", default "block"
RiskAction string
// Dimension-level action fields (optional, empty string means not configured).
// Validated and honoured only for MultiModalGuard / MultiModalGuardForBase64.
ContentModerationAction string
PromptAttackAction string
SensitiveDataAction string
MaliciousUrlAction string
ModelHallucinationAction string
CustomLabelAction string
}
// Parse populates config from the plugin's JSON configuration.
//
// The required fields are serviceName, servicePort and serviceHost (the
// security service cluster) and accessKey/secretKey. A missing one yields an
// error. Once "action" is known, Parse applies the action-dependent defaults
// via SetDefaultValues, and then every field present in json overrides its
// default.
//
// Enumerated fields (riskAction, the dimension *Action fields and the
// *LevelBar fields) are validated. Consumer rules with matchType "regexp" must
// compile. An unsupported value is reported as an error rather than panicking
// the plugin.
//
// On success Parse also builds the cluster client and initialises Metrics.
func (config *AISecurityConfig) Parse(json gjson.Result) error {
	serviceName := json.Get("serviceName").String()
	servicePort := json.Get("servicePort").Int()
	serviceHost := json.Get("serviceHost").String()
	config.Host = serviceHost
	if serviceName == "" || servicePort == 0 || serviceHost == "" {
		return errors.New("invalid service config")
	}
	config.AK = json.Get("accessKey").String()
	config.SK = json.Get("secretKey").String()
	if config.AK == "" || config.SK == "" {
		return errors.New("invalid AK/SK config")
	}
	config.Token = json.Get("securityToken").String()
	// set action; TextModerationPlus when not configured
	if obj := json.Get("action"); obj.Exists() {
		config.Action = obj.String()
	} else {
		config.Action = TextModerationPlus
	}
	// Defaults depend on Action, so apply them now and let the explicit
	// fields below override them.
	config.SetDefaultValues()
	// set riskAction
	if obj := json.Get("riskAction"); obj.Exists() {
		config.RiskAction = obj.String()
		if config.RiskAction != "block" && config.RiskAction != "mask" {
			return errors.New("invalid riskAction, value must be one of [block, mask]")
		}
	}
	// parse global dimension action fields
	isMultiModalGuard := config.Action == MultiModalGuard || config.Action == MultiModalGuardForBase64
	dimensionActionFields := []struct {
		fieldName string
		target    *string
	}{
		{"contentModerationAction", &config.ContentModerationAction},
		{"promptAttackAction", &config.PromptAttackAction},
		{"sensitiveDataAction", &config.SensitiveDataAction},
		{"maliciousUrlAction", &config.MaliciousUrlAction},
		{"modelHallucinationAction", &config.ModelHallucinationAction},
		{"customLabelAction", &config.CustomLabelAction},
	}
	hasDimensionAction := false
	for _, field := range dimensionActionFields {
		if isMultiModalGuard {
			val, err := parseDimensionAction(json, field.fieldName)
			if err != nil {
				return err
			}
			*field.target = val
			if val != "" {
				hasDimensionAction = true
			}
		} else {
			// Non-MultiModalGuard: read value without validation, field will be ignored at runtime
			if obj := json.Get(field.fieldName); obj.Exists() {
				*field.target = obj.String()
				hasDimensionAction = true
			}
		}
	}
	if hasDimensionAction && !isMultiModalGuard {
		proxywasm.LogWarnf("dimension action fields are configured but will be ignored because action is %s (not MultiModalGuard/MultiModalGuardForBase64)", config.Action)
	}
	// set values
	if obj := json.Get("riskLevelBar"); obj.Exists() {
		config.RiskLevelBar = obj.String()
	}
	if obj := json.Get("requestCheckService"); obj.Exists() {
		config.RequestCheckService = obj.String()
	}
	if obj := json.Get("requestImageCheckService"); obj.Exists() {
		config.RequestImageCheckService = obj.String()
	}
	if obj := json.Get("responseCheckService"); obj.Exists() {
		config.ResponseCheckService = obj.String()
	}
	if obj := json.Get("responseImageCheckService"); obj.Exists() {
		config.ResponseImageCheckService = obj.String()
	}
	config.CheckRequest = json.Get("checkRequest").Bool()
	config.CheckRequestImage = json.Get("checkRequestImage").Bool()
	config.CheckResponse = json.Get("checkResponse").Bool()
	config.ProtocolOriginal = json.Get("protocol").String() == "original"
	config.DenyMessage = json.Get("denyMessage").String()
	if obj := json.Get("denyCode"); obj.Exists() {
		config.DenyCode = obj.Int()
	}
	if obj := json.Get("requestContentJsonPath"); obj.Exists() {
		config.RequestContentJsonPath = obj.String()
	}
	if obj := json.Get("responseContentJsonPath"); obj.Exists() {
		config.ResponseContentJsonPath = obj.String()
	}
	if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
		config.ResponseStreamContentJsonPath = obj.String()
	}
	if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
		config.ContentModerationLevelBar = obj.String()
		if LevelToInt(config.ContentModerationLevelBar) <= 0 {
			return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
		config.PromptAttackLevelBar = obj.String()
		if LevelToInt(config.PromptAttackLevelBar) <= 0 {
			return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
		config.SensitiveDataLevelBar = obj.String()
		if LevelToInt(config.SensitiveDataLevelBar) <= 0 {
			return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
		}
	}
	if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
		config.ModelHallucinationLevelBar = obj.String()
		if LevelToInt(config.ModelHallucinationLevelBar) <= 0 {
			return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
		config.MaliciousUrlLevelBar = obj.String()
		if LevelToInt(config.MaliciousUrlLevelBar) <= 0 {
			return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("customLabelLevelBar"); obj.Exists() {
		config.CustomLabelLevelBar = obj.String()
		if LevelToInt(config.CustomLabelLevelBar) <= 0 {
			return errors.New("invalid customLabelLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("timeout"); obj.Exists() {
		config.Timeout = uint32(obj.Int())
	}
	if obj := json.Get("bufferLimit"); obj.Exists() {
		config.BufferLimit = int(obj.Int())
	}
	// per-consumer override rules
	if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
		rules, err := parseConsumerMatchRules(obj, "consumerRequestCheckService", nil)
		if err != nil {
			return err
		}
		config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, rules...)
	}
	if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
		rules, err := parseConsumerMatchRules(obj, "consumerResponseCheckService", nil)
		if err != nil {
			return err
		}
		config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, rules...)
	}
	if obj := json.Get("consumerRiskLevel"); obj.Exists() {
		rules, err := parseConsumerMatchRules(obj, "consumerRiskLevel", func(rule map[string]interface{}) error {
			return validateConsumerRiskActions(rule, isMultiModalGuard)
		})
		if err != nil {
			return err
		}
		config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, rules...)
	}
	if obj := json.Get("apiType"); obj.Exists() {
		config.ApiType = obj.String()
	}
	if obj := json.Get("providerType"); obj.Exists() {
		config.ProviderType = obj.String()
	}
	config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{
		FQDN: serviceName,
		Port: servicePort,
		Host: serviceHost,
	})
	config.Metrics = make(map[string]proxywasm.MetricCounter)
	return nil
}

// parseConsumerMatchRules converts the JSON array rules (the value of the
// config field named field) into generic maps. Each kept map gets a compiled
// Matcher stored under the "matcher" key.
//
// Entries without both "name" and "matchType" are skipped. An entry with an
// unknown matchType is kept but gets no matcher, so it never matches. A
// "regexp" name that fails to compile is returned as an error instead of
// panicking. When validate is non-nil it runs on every kept entry, and its
// first error aborts parsing.
func parseConsumerMatchRules(rules gjson.Result, field string, validate func(map[string]interface{}) error) ([]map[string]interface{}, error) {
	var out []map[string]interface{}
	for _, item := range rules.Array() {
		m := make(map[string]interface{})
		for k, v := range item.Map() {
			m[k] = v.Value()
		}
		consumerName, hasName := m["name"]
		matchType, hasMatchType := m["matchType"]
		if !hasName || !hasMatchType {
			continue
		}
		pattern := fmt.Sprint(consumerName)
		switch fmt.Sprint(matchType) {
		case "exact":
			m["matcher"] = Matcher{Exact: pattern}
		case "prefix":
			m["matcher"] = Matcher{Prefix: pattern}
		case "regexp":
			re, err := regexp.Compile(pattern)
			if err != nil {
				return nil, fmt.Errorf("invalid regexp %q in %s: %w", pattern, field, err)
			}
			m["matcher"] = Matcher{Re: re}
		}
		if validate != nil {
			if err := validate(m); err != nil {
				return nil, err
			}
		}
		out = append(out, m)
	}
	return out, nil
}

// validateConsumerRiskActions checks the action fields of one consumerRiskLevel
// entry. riskAction is always checked. The per-dimension *Action fields are
// checked only when checkDimensions is true (MultiModalGuard actions). Every
// present value must be "block" or "mask".
func validateConsumerRiskActions(rule map[string]interface{}, checkDimensions bool) error {
	if ra, ok := rule["riskAction"]; ok {
		raStr := fmt.Sprint(ra)
		if raStr != "block" && raStr != "mask" {
			return errors.New("invalid riskAction in consumerRiskLevel, value must be one of [block, mask]")
		}
	}
	if !checkDimensions {
		return nil
	}
	consumerDimensionActionFields := []string{
		"contentModerationAction",
		"promptAttackAction",
		"sensitiveDataAction",
		"maliciousUrlAction",
		"modelHallucinationAction",
		"customLabelAction",
	}
	for _, fieldName := range consumerDimensionActionFields {
		if v, ok := rule[fieldName]; ok {
			vStr := fmt.Sprint(v)
			if vStr != "block" && vStr != "mask" {
				return fmt.Errorf("invalid %s in consumerRiskLevel, value must be one of [block, mask]", fieldName)
			}
		}
	}
	return nil
}
// parseDimensionAction reads the optional per-dimension action fieldName from
// json.
//
// It returns:
//   - ("", nil) when the field is absent,
//   - the value when it is "block" or "mask",
//   - an error for any other value.
func parseDimensionAction(json gjson.Result, fieldName string) (string, error) {
	field := json.Get(fieldName)
	if !field.Exists() {
		return "", nil
	}
	switch value := field.String(); value {
	case "block", "mask":
		return value, nil
	default:
		return "", fmt.Errorf("invalid %s, value must be one of [block, mask]", fieldName)
	}
}
// SetDefaultValues resets the tunable fields to their built-in defaults.
//
// The check-service names depend on config.Action, so Action must be set
// before this is called:
//   - TextModerationPlus gets the llm_*_moderation services.
//   - MultiModalGuard and MultiModalGuardForBase64 get the *_security_check
//     services, including the image input service.
//   - Any other Action leaves the check-service fields untouched.
func (config *AISecurityConfig) SetDefaultValues() {
	switch config.Action {
	case TextModerationPlus:
		config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService
		config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService
	case MultiModalGuard, MultiModalGuardForBase64:
		// The base64 variant is treated exactly like MultiModalGuard elsewhere
		// (dimension actions, EvaluateRisk). Without this case it would start
		// with empty service names unless they were configured explicitly.
		config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService
		config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService
		config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService
	}
	config.RiskLevelBar = HighRisk
	config.DenyCode = DefaultDenyCode
	config.RequestContentJsonPath = DefaultRequestJsonPath
	config.ResponseContentJsonPath = DefaultResponseJsonPath
	config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
	config.ContentModerationLevelBar = MaxRisk
	config.PromptAttackLevelBar = MaxRisk
	config.SensitiveDataLevelBar = S4Sensitive
	config.ModelHallucinationLevelBar = MaxRisk
	config.MaliciousUrlLevelBar = MaxRisk
	config.CustomLabelLevelBar = MaxRisk
	config.Timeout = DefaultTimeout
	config.BufferLimit = 1000
	config.ApiType = ApiTextGeneration
	config.ProviderType = ProviderOpenAI
	config.RiskAction = "block"
}
// IncrementCounter adds inc to the counter metric named metricName. The metric
// is defined on first use and cached in config.Metrics for later calls.
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
	// Parse initialises Metrics, but a config built any other way would have a
	// nil map, and writing to a nil map panics.
	if config.Metrics == nil {
		config.Metrics = make(map[string]proxywasm.MetricCounter)
	}
	counter, ok := config.Metrics[metricName]
	if !ok {
		counter = proxywasm.DefineCounterMetric(metricName)
		config.Metrics[metricName] = counter
	}
	counter.Increment(inc)
}
// lookupConsumerOverride returns the string stored under key in the first rule
// whose matcher accepts consumer.
//
// Matching is first-match: once a rule matches, later rules are not consulted,
// even if the matched rule lacks key. fallback is returned when no rule matches
// or the matched rule has no key. If the matched rule holds a non-string value
// under key, the result is "".
func lookupConsumerOverride(rules []map[string]interface{}, consumer, key, fallback string) string {
	for _, rule := range rules {
		matcher, ok := rule["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		if v, ok := rule[key]; ok {
			s, _ := v.(string)
			return s
		}
		return fallback
	}
	return fallback
}

// GetRequestCheckService returns the text check service for consumer's
// requests. A matching consumerRequestCheckService rule takes precedence over
// the global RequestCheckService.
func (config *AISecurityConfig) GetRequestCheckService(consumer string) string {
	return lookupConsumerOverride(config.ConsumerRequestCheckService, consumer, "requestCheckService", config.RequestCheckService)
}

// GetRequestImageCheckService returns the image check service for consumer's
// requests. A matching consumerRequestCheckService rule takes precedence over
// the global RequestImageCheckService.
func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string {
	return lookupConsumerOverride(config.ConsumerRequestCheckService, consumer, "requestImageCheckService", config.RequestImageCheckService)
}

// GetResponseCheckService returns the text check service for consumer's
// responses. A matching consumerResponseCheckService rule takes precedence over
// the global ResponseCheckService.
func (config *AISecurityConfig) GetResponseCheckService(consumer string) string {
	return lookupConsumerOverride(config.ConsumerResponseCheckService, consumer, "responseCheckService", config.ResponseCheckService)
}

// GetResponseImageCheckService returns the image check service for consumer's
// responses. A matching consumerResponseCheckService rule takes precedence over
// the global ResponseImageCheckService.
func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string {
	return lookupConsumerOverride(config.ConsumerResponseCheckService, consumer, "responseImageCheckService", config.ResponseImageCheckService)
}
// getMatchedConsumerRiskRule finds the consumerRiskLevel rule that applies to
// consumer. Rules are tried in configuration order, and the first one whose
// Matcher accepts consumer wins. Entries without a Matcher (unknown matchType)
// are skipped. It returns (nil, false) when no rule applies.
func (config *AISecurityConfig) getMatchedConsumerRiskRule(consumer string) (map[string]interface{}, bool) {
	for i := range config.ConsumerRiskLevel {
		rule := config.ConsumerRiskLevel[i]
		m, hasMatcher := rule["matcher"].(Matcher)
		if hasMatcher && m.match(consumer) {
			return rule, true
		}
	}
	return nil, false
}
// consumerRiskOverride returns the string under key in consumer's first
// matching consumerRiskLevel rule.
//
// fallback is returned when no rule matches or the rule lacks key. A non-string
// value under key yields "". For level bars, "" maps to -1 in LevelToInt, so
// every level is treated as reaching the threshold (fail-closed).
func (config *AISecurityConfig) consumerRiskOverride(consumer, key, fallback string) string {
	rule, ok := config.getMatchedConsumerRiskRule(consumer)
	if !ok {
		return fallback
	}
	v, ok := rule[key]
	if !ok {
		return fallback
	}
	s, _ := v.(string)
	return s
}

// GetRiskLevelBar returns the top-level risk threshold for consumer
// (consumer rule "riskLevelBar", else RiskLevelBar).
func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "riskLevelBar", config.RiskLevelBar)
}

// GetContentModerationLevelBar returns the content-moderation threshold for consumer.
func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "contentModerationLevelBar", config.ContentModerationLevelBar)
}

// GetPromptAttackLevelBar returns the prompt-attack threshold for consumer.
func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "promptAttackLevelBar", config.PromptAttackLevelBar)
}

// GetSensitiveDataLevelBar returns the sensitive-data threshold for consumer.
func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "sensitiveDataLevelBar", config.SensitiveDataLevelBar)
}

// GetMaliciousUrlLevelBar returns the malicious-URL threshold for consumer.
func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "maliciousUrlLevelBar", config.MaliciousUrlLevelBar)
}

// GetModelHallucinationLevelBar returns the model-hallucination threshold for consumer.
func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "modelHallucinationLevelBar", config.ModelHallucinationLevelBar)
}

// GetCustomLabelLevelBar returns the custom-label threshold for consumer.
func (config *AISecurityConfig) GetCustomLabelLevelBar(consumer string) string {
	return config.consumerRiskOverride(consumer, "customLabelLevelBar", config.CustomLabelLevelBar)
}

// GetRiskAction returns the global risk action for consumer (consumer rule
// "riskAction", else RiskAction).
func (config *AISecurityConfig) GetRiskAction(consumer string) string {
	return config.consumerRiskOverride(consumer, "riskAction", config.RiskAction)
}
// dimensionActionKey returns the consumerRiskLevel map key that holds the
// per-dimension action for detailType, e.g. SensitiveDataType maps to
// "sensitiveDataAction". Types without a dimension action yield "".
func dimensionActionKey(detailType string) string {
	pairs := [...][2]string{
		{ContentModerationType, "contentModerationAction"},
		{PromptAttackType, "promptAttackAction"},
		{SensitiveDataType, "sensitiveDataAction"},
		{MaliciousUrlDataType, "maliciousUrlAction"},
		{ModelHallucinationDataType, "modelHallucinationAction"},
		{CustomLabelType, "customLabelAction"},
	}
	for _, pair := range pairs {
		if pair[0] == detailType {
			return pair[1]
		}
	}
	return ""
}
// getGlobalDimensionAction returns the globally configured action for the
// dimension detailType. It is "" when that dimension has no action configured
// or detailType has no dimension action field.
func (config *AISecurityConfig) getGlobalDimensionAction(detailType string) string {
	var action string
	switch detailType {
	case ContentModerationType:
		action = config.ContentModerationAction
	case PromptAttackType:
		action = config.PromptAttackAction
	case SensitiveDataType:
		action = config.SensitiveDataAction
	case MaliciousUrlDataType:
		action = config.MaliciousUrlAction
	case ModelHallucinationDataType:
		action = config.ModelHallucinationAction
	case CustomLabelType:
		action = config.CustomLabelAction
	}
	return action
}
// enforceMaskBoundary applies the rule that only the sensitiveData dimension
// can be masked (desensitized). A "mask" action for any other dimension is
// logged and downgraded to "block". source is returned unchanged so callers can
// still report where the action came from.
func enforceMaskBoundary(action, detailType, source string) (string, string) {
	unsupportedMask := action == "mask" && detailType != SensitiveDataType
	if !unsupportedMask {
		return action, source
	}
	proxywasm.LogWarnf("mask action not supported for dimension %s, downgrading to block", detailType)
	return "block", source
}
// ResolveRiskActionByType picks the action ("block" or "mask") for the
// dimension detailType as seen by consumer. It uses the first non-empty source
// in this priority order:
//
//  1. consumer_dimension: the matched consumer rule's per-dimension action
//  2. consumer_global:    the matched consumer rule's riskAction
//  3. global_dimension:   the global per-dimension action field
//  4. global_global:      the global RiskAction
//  5. default:            "block"
//
// Any "mask" for a dimension other than sensitiveData is downgraded to "block"
// (see enforceMaskBoundary). The second return value names the source used.
func (config *AISecurityConfig) ResolveRiskActionByType(consumer string, detailType string) (string, string) {
	if rule, matched := config.getMatchedConsumerRiskRule(consumer); matched {
		if key := dimensionActionKey(detailType); key != "" {
			// A missing key yields a nil interface, so the assertion fails.
			if s, isStr := rule[key].(string); isStr && s != "" {
				return enforceMaskBoundary(s, detailType, "consumer_dimension")
			}
		}
		if s, isStr := rule["riskAction"].(string); isStr && s != "" {
			return enforceMaskBoundary(s, detailType, "consumer_global")
		}
	}
	if s := config.getGlobalDimensionAction(detailType); s != "" {
		return enforceMaskBoundary(s, detailType, "global_dimension")
	}
	if config.RiskAction != "" {
		return enforceMaskBoundary(config.RiskAction, detailType, "global_global")
	}
	return "block", "default"
}
// LevelToInt maps a risk level (max/high/medium/low/none) or a sensitivity
// level (s4/s3/s2/s1/s0), case-insensitively, onto a shared 4..0 scale so the
// two families can be compared with the same thresholds. Unrecognised input,
// including "", yields -1.
func LevelToInt(riskLevel string) int {
	level := strings.ToLower(riskLevel)
	switch {
	case level == MaxRisk || level == S4Sensitive:
		return 4
	case level == HighRisk || level == S3Sensitive:
		return 3
	case level == MediumRisk || level == S2Sensitive:
		return 2
	case level == LowRisk || level == S1Sensitive:
		return 1
	case level == NoRisk || level == NoSensitive:
		return 0
	}
	return -1
}
// RiskResult is the verdict produced by EvaluateRisk.
type RiskResult int
const (
RiskPass RiskResult = iota // pass: let the content through
RiskMask // needs desensitization (masking)
RiskBlock // needs to be blocked
)
// EvaluateRisk turns the security service's verdict data into a RiskResult
// for consumer.
//
// MultiModalGuard and MultiModalGuardForBase64 use the per-dimension flow in
// evaluateRiskMultiModal. Every other action (e.g. TextModerationPlus) ignores
// dimension actions: it blocks when data.RiskLevel reaches the consumer's
// RiskLevelBar and passes otherwise.
func EvaluateRisk(action string, data Data, config AISecurityConfig, consumer string) RiskResult {
	switch action {
	case MultiModalGuard, MultiModalGuardForBase64:
		return evaluateRiskMultiModal(data, config, consumer)
	}
	if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetRiskLevelBar(consumer)) {
		return RiskBlock
	}
	return RiskPass
}
// evaluateRiskMultiModal is the per-dimension risk evaluation used for the
// MultiModalGuard actions. The original comment says it follows the design
// doc's section 11.1-7 pseudocode. It runs in three stages:
//
//  1. A top-level compatibility gate: block when RiskLevel reaches the content
//     moderation bar or AttackLevel reaches the prompt-attack bar.
//  2. Each Detail in order: resolve its action and threshold, block on the
//     first detail that demands it, and remember whether any eligible
//     sensitive-data detail asked for masking.
//  3. A fallback: block when data.Suggestion is "block"; otherwise return mask
//     if step 2 found one, else pass.
func evaluateRiskMultiModal(data Data, config AISecurityConfig, consumer string) RiskResult {
	if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) ||
		LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
		return RiskBlock
	}

	maskPending := false
	for i := range data.Detail {
		d := data.Detail[i]
		action, source := config.ResolveRiskActionByType(consumer, d.Type)
		over := detailExceedsThreshold(d, config, consumer)
		proxywasm.LogInfof("safecheck_risk_type=%s, safecheck_resolved_action=%s, safecheck_action_source=%s",
			d.Type, action, source)
		if detailTriggersBlock(d, action, over) {
			return RiskBlock
		}
		// Only a "mask" action paired with a "mask" suggestion can lead to
		// masking. Non-sensitiveData dimensions were already downgraded to
		// "block" by enforceMaskBoundary.
		if action != "mask" || d.Suggestion != "mask" {
			continue
		}
		if !over {
			proxywasm.LogInfof("safecheck_mask_skipped: type=%s, suggestion=%s, level=%s, threshold=%s",
				d.Type, d.Suggestion, d.Level, config.GetSensitiveDataLevelBar(consumer))
			continue
		}
		maskPending = true
	}

	switch {
	case data.Suggestion == "block":
		return RiskBlock
	case maskPending:
		return RiskMask
	default:
		return RiskPass
	}
}
// detailTriggersBlock decides whether a single detail forces a block, given its
// resolved dimension action and whether its level exceeds the threshold.
//
//   - The service's own "block" suggestion always blocks.
//   - Under any non-block action (normally "mask"), a "mask" suggestion is let
//     through so the caller can desensitize the content.
//   - Otherwise the detail blocks exactly when it exceeds the threshold.
func detailTriggersBlock(detail Detail, dimAction string, exceeds bool) bool {
	switch {
	case detail.Suggestion == "block":
		return true
	case dimAction != "block" && detail.Suggestion == "mask":
		return false
	default:
		return exceeds
	}
}
// detailExceedsThreshold reports whether detail.Level is at or above
// consumer's level bar for the detail's dimension. Dimensions without a
// configured bar (e.g. maliciousFile, waterMark) never exceed.
func detailExceedsThreshold(detail Detail, config AISecurityConfig, consumer string) bool {
	var bar string
	switch detail.Type {
	case ContentModerationType:
		bar = config.GetContentModerationLevelBar(consumer)
	case PromptAttackType:
		bar = config.GetPromptAttackLevelBar(consumer)
	case SensitiveDataType:
		bar = config.GetSensitiveDataLevelBar(consumer)
	case MaliciousUrlDataType:
		bar = config.GetMaliciousUrlLevelBar(consumer)
	case ModelHallucinationDataType:
		bar = config.GetModelHallucinationLevelBar(consumer)
	case CustomLabelType:
		bar = config.GetCustomLabelLevelBar(consumer)
	default:
		return false
	}
	return LevelToInt(detail.Level) >= LevelToInt(bar)
}
// IsRiskLevelAcceptable reports whether the content may proceed, i.e. whether
// EvaluateRisk returns anything other than RiskBlock. RiskMask counts as
// acceptable.
func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
	blocked := EvaluateRisk(action, data, config, consumer) == RiskBlock
	return !blocked
}
// ExtractDesensitization returns the masked text supplied by the security
// service.
//
// It scans data.Detail in order. For each entry with Type == SensitiveDataType
// and Suggestion == "mask" it looks only at the entry's first Result, and it
// returns the first non-empty Ext.Desensitization found this way. Entries with
// no Result or an empty Desensitization are skipped. It returns "" when nothing
// qualifies.
func ExtractDesensitization(data Data) string {
	for _, d := range data.Detail {
		if d.Type != SensitiveDataType || d.Suggestion != "mask" || len(d.Result) == 0 {
			continue
		}
		if masked := d.Result[0].Ext.Desensitization; masked != "" {
			return masked
		}
	}
	return ""
}
// BlockedDetail is one entry of DenyResponseBody.BlockedDetails: the dimension
// type and level of a detail that caused the block.
type BlockedDetail struct {
Type string `json:"type"`
Level string `json:"level"`
}
// DenyResponseBody is the JSON body built by BuildDenyResponseBody. Code is
// copied from the service Response. denyMessage is omitted when empty.
// blockedDetails has no omitempty and is always present.
type DenyResponseBody struct {
Code int `json:"code"`
DenyMessage string `json:"denyMessage,omitempty"`
BlockedDetails []BlockedDetail `json:"blockedDetails"`
}
// BuildDenyResponseBody marshals the JSON body sent when content is blocked.
// The body carries:
//   - the service's response Code,
//   - the configured DenyMessage (omitted when empty),
//   - the type/level of each detail chosen by GetUnacceptableDetail.
//
// blockedDetails is always a JSON array, possibly empty.
func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer string) ([]byte, error) {
	unacceptable := GetUnacceptableDetail(response.Data, config, consumer)
	body := DenyResponseBody{
		Code:           response.Code,
		DenyMessage:    config.DenyMessage,
		BlockedDetails: make([]BlockedDetail, len(unacceptable)),
	}
	for i, d := range unacceptable {
		body.BlockedDetails[i] = BlockedDetail{Type: d.Type, Level: d.Level}
	}
	return json.Marshal(body)
}
// GetUnacceptableDetail returns the details of data that would trigger a block
// for consumer, using the same per-detail rule as evaluateRiskMultiModal.
//
// When no detail qualifies (including when data.Detail is empty), it
// synthesises block entries from the top-level RiskLevel/AttackLevel if they
// reach their bars. That way a block caused by the top-level gate still
// reports something. A block caused only by data.Suggestion == "block" can
// still yield an empty slice. The result is never nil, so it marshals as a JSON
// array.
func GetUnacceptableDetail(data Data, config AISecurityConfig, consumer string) []Detail {
	blocking := make([]Detail, 0)
	for _, d := range data.Detail {
		action, _ := config.ResolveRiskActionByType(consumer, d.Type)
		if detailTriggersBlock(d, action, detailExceedsThreshold(d, config, consumer)) {
			blocking = append(blocking, d)
		}
	}
	if len(blocking) > 0 {
		return blocking
	}
	if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
		blocking = append(blocking, Detail{
			Type:       ContentModerationType,
			Level:      data.RiskLevel,
			Suggestion: "block",
		})
	}
	if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
		blocking = append(blocking, Detail{
			Type:       PromptAttackType,
			Level:      data.AttackLevel,
			Suggestion: "block",
		})
	}
	return blocking
}