support consumer specific check service || =support consumer specific check service (#2891)

This commit is contained in:
rinfx
2025-09-09 17:35:05 +08:00
committed by GitHub
parent 5384481704
commit 9f0f3de540

View File

@@ -108,6 +108,8 @@ type AISecurityConfig struct {
timeout uint32
bufferLimit int
metrics map[string]proxywasm.MetricCounter
consumerSpecificRequestCheckService map[string]string
consumerSpecificResponseCheckService map[string]string
}
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
@@ -242,6 +244,14 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error {
} else {
config.bufferLimit = 1000
}
config.consumerSpecificRequestCheckService = make(map[string]string)
for k, v := range json.Get("consumerSpecificRequestCheckService").Map() {
config.consumerSpecificRequestCheckService[k] = v.String()
}
config.consumerSpecificResponseCheckService = make(map[string]string)
for k, v := range json.Get("consumerSpecificResponseCheckService").Map() {
config.consumerSpecificResponseCheckService[k] = v.String()
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
@@ -261,6 +271,8 @@ func generateRandomID() string {
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action {
consumer, _ := proxywasm.GetHttpRequestHeader("x-mse-consumer")
ctx.SetContext("consumer", consumer)
ctx.DisableReroute()
if !config.checkRequest {
log.Debugf("request checking is disabled")
@@ -347,6 +359,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificRequestCheckService[consumer]
if !ok {
checkService = config.requestCheckService
}
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
@@ -356,7 +373,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)),
}
if config.token != "" {
@@ -467,6 +484,11 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
log.Debugf("current content piece: %s", buffer)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
if !ok {
checkService = config.responseCheckService
}
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
@@ -476,7 +498,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer)),
}
if config.token != "" {
@@ -593,6 +615,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
if !ok {
checkService = config.responseCheckService
}
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
@@ -602,7 +629,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)),
}
if config.token != "" {