support consumer specific check service || =support consumer specific check service (#2891)

This commit is contained in:
rinfx
2025-09-09 17:35:05 +08:00
committed by GitHub
parent 5384481704
commit 9f0f3de540

View File

@@ -90,24 +90,26 @@ type Advice struct {
} }
type AISecurityConfig struct { type AISecurityConfig struct {
client wrapper.HttpClient client wrapper.HttpClient
ak string ak string
sk string sk string
token string token string
checkRequest bool checkRequest bool
requestCheckService string requestCheckService string
requestContentJsonPath string requestContentJsonPath string
checkResponse bool checkResponse bool
responseCheckService string responseCheckService string
responseContentJsonPath string responseContentJsonPath string
responseStreamContentJsonPath string responseStreamContentJsonPath string
denyCode int64 denyCode int64
denyMessage string denyMessage string
protocolOriginal bool protocolOriginal bool
riskLevelBar string riskLevelBar string
timeout uint32 timeout uint32
bufferLimit int bufferLimit int
metrics map[string]proxywasm.MetricCounter metrics map[string]proxywasm.MetricCounter
consumerSpecificRequestCheckService map[string]string
consumerSpecificResponseCheckService map[string]string
} }
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) { func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
@@ -242,6 +244,14 @@ func parseConfig(json gjson.Result, config *AISecurityConfig) error {
} else { } else {
config.bufferLimit = 1000 config.bufferLimit = 1000
} }
config.consumerSpecificRequestCheckService = make(map[string]string)
for k, v := range json.Get("consumerSpecificRequestCheckService").Map() {
config.consumerSpecificRequestCheckService[k] = v.String()
}
config.consumerSpecificResponseCheckService = make(map[string]string)
for k, v := range json.Get("consumerSpecificResponseCheckService").Map() {
config.consumerSpecificResponseCheckService[k] = v.String()
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName, FQDN: serviceName,
Port: servicePort, Port: servicePort,
@@ -261,6 +271,8 @@ func generateRandomID() string {
} }
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action {
consumer, _ := proxywasm.GetHttpRequestHeader("x-mse-consumer")
ctx.SetContext("consumer", consumer)
ctx.DisableReroute() ctx.DisableReroute()
if !config.checkRequest { if !config.checkRequest {
log.Debugf("request checking is disabled") log.Debugf("request checking is disabled")
@@ -347,6 +359,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
contentPiece := content[contentIndex:nextContentIndex] contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece) log.Debugf("current content piece: %s", contentPiece)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificRequestCheckService[consumer]
if !ok {
checkService = config.requestCheckService
}
params := map[string]string{ params := map[string]string{
"Format": "JSON", "Format": "JSON",
"Version": "2022-03-02", "Version": "2022-03-02",
@@ -356,7 +373,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"Action": "TextModerationPlus", "Action": "TextModerationPlus",
"AccessKeyId": config.ak, "AccessKeyId": config.ak,
"Timestamp": timestamp, "Timestamp": timestamp,
"Service": config.requestCheckService, "Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)), "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)),
} }
if config.token != "" { if config.token != "" {
@@ -467,6 +484,11 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16) randomID, _ := generateHexID(16)
log.Debugf("current content piece: %s", buffer) log.Debugf("current content piece: %s", buffer)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
if !ok {
checkService = config.responseCheckService
}
params := map[string]string{ params := map[string]string{
"Format": "JSON", "Format": "JSON",
"Version": "2022-03-02", "Version": "2022-03-02",
@@ -476,7 +498,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfi
"Action": "TextModerationPlus", "Action": "TextModerationPlus",
"AccessKeyId": config.ak, "AccessKeyId": config.ak,
"Timestamp": timestamp, "Timestamp": timestamp,
"Service": config.responseCheckService, "Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer)), "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer)),
} }
if config.token != "" { if config.token != "" {
@@ -593,6 +615,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
contentPiece := content[contentIndex:nextContentIndex] contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece) log.Debugf("current content piece: %s", contentPiece)
consumer, _ := ctx.GetContext("consumer").(string)
checkService, ok := config.consumerSpecificResponseCheckService[consumer]
if !ok {
checkService = config.responseCheckService
}
params := map[string]string{ params := map[string]string{
"Format": "JSON", "Format": "JSON",
"Version": "2022-03-02", "Version": "2022-03-02",
@@ -602,7 +629,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"Action": "TextModerationPlus", "Action": "TextModerationPlus",
"AccessKeyId": config.ak, "AccessKeyId": config.ak,
"Timestamp": timestamp, "Timestamp": timestamp,
"Service": config.responseCheckService, "Service": checkService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)), "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)),
} }
if config.token != "" { if config.token != "" {