diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 7fc91be43..a2005d8ed 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -90,24 +90,26 @@ type Advice struct { } type AISecurityConfig struct { - client wrapper.HttpClient - ak string - sk string - token string - checkRequest bool - requestCheckService string - requestContentJsonPath string - checkResponse bool - responseCheckService string - responseContentJsonPath string - responseStreamContentJsonPath string - denyCode int64 - denyMessage string - protocolOriginal bool - riskLevelBar string - timeout uint32 - bufferLimit int - metrics map[string]proxywasm.MetricCounter + client wrapper.HttpClient + ak string + sk string + token string + checkRequest bool + requestCheckService string + requestContentJsonPath string + checkResponse bool + responseCheckService string + responseContentJsonPath string + responseStreamContentJsonPath string + denyCode int64 + denyMessage string + protocolOriginal bool + riskLevelBar string + timeout uint32 + bufferLimit int + metrics map[string]proxywasm.MetricCounter + consumerSpecificRequestCheckService map[string]string + consumerSpecificResponseCheckService map[string]string } func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) { @@ -242,6 +244,14 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error { } else { config.bufferLimit = 1000 } + config.consumerSpecificRequestCheckService = make(map[string]string) + for k, v := range json.Get("consumerSpecificRequestCheckService").Map() { + config.consumerSpecificRequestCheckService[k] = v.String() + } + config.consumerSpecificResponseCheckService = make(map[string]string) + for k, v := range json.Get("consumerSpecificResponseCheckService").Map() { + config.consumerSpecificResponseCheckService[k] = v.String() + } config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: serviceName, Port: servicePort, @@ -261,6 +271,8 @@ func generateRandomID() string { } func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action { + consumer, _ := proxywasm.GetHttpRequestHeader("x-mse-consumer") + ctx.SetContext("consumer", consumer) ctx.DisableReroute() if !config.checkRequest { log.Debugf("request checking is disabled") @@ -347,6 +359,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] contentPiece := content[contentIndex:nextContentIndex] contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) + consumer, _ := ctx.GetContext("consumer").(string) + checkService, ok := config.consumerSpecificRequestCheckService[consumer] + if !ok { + checkService = config.requestCheckService + } params := map[string]string{ "Format": "JSON", "Version": "2022-03-02", @@ -356,7 +373,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] "Action": "TextModerationPlus", "AccessKeyId": config.ak, "Timestamp": timestamp, - "Service": config.requestCheckService, + "Service": checkService, "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)), } if config.token != "" { @@ -467,6 +484,11 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) log.Debugf("current content piece: %s", buffer) + consumer, _ := ctx.GetContext("consumer").(string) + checkService, ok := config.consumerSpecificResponseCheckService[consumer] + if !ok { + checkService = config.responseCheckService + } params := map[string]string{ "Format": "JSON", "Version": "2022-03-02", @@ -476,7 +498,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi "Action": "TextModerationPlus", "AccessKeyId": config.ak, "Timestamp": timestamp, - "Service": config.responseCheckService, + "Service": checkService, "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer)), } if config.token != "" { @@ -593,6 +615,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ contentPiece := content[contentIndex:nextContentIndex] contentIndex = nextContentIndex log.Debugf("current content piece: %s", contentPiece) + consumer, _ := ctx.GetContext("consumer").(string) + checkService, ok := config.consumerSpecificResponseCheckService[consumer] + if !ok { + checkService = config.responseCheckService + } params := map[string]string{ "Format": "JSON", "Version": "2022-03-02", @@ -602,7 +629,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ "Action": "TextModerationPlus", "AccessKeyId": config.ak, "Timestamp": timestamp, - "Service": config.responseCheckService, + "Service": checkService, "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)), } if config.token != "" {