|
|
|
@@ -29,11 +29,12 @@ func main() {}
|
|
|
|
|
func init() {
|
|
|
|
|
wrapper.SetCtx(
|
|
|
|
|
"ai-security-guard",
|
|
|
|
|
wrapper.ParseConfigBy(parseConfig),
|
|
|
|
|
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
|
|
|
|
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
|
|
|
|
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
|
|
|
|
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
|
|
|
|
wrapper.ParseConfig(parseConfig),
|
|
|
|
|
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
|
|
|
|
wrapper.ProcessRequestBody(onHttpRequestBody),
|
|
|
|
|
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
|
|
|
|
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
|
|
|
|
|
wrapper.ProcessResponseBody(onHttpResponseBody),
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -105,6 +106,7 @@ type AISecurityConfig struct {
|
|
|
|
|
protocolOriginal bool
|
|
|
|
|
riskLevelBar string
|
|
|
|
|
timeout uint32
|
|
|
|
|
bufferLimit int
|
|
|
|
|
metrics map[string]proxywasm.MetricCounter
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -175,7 +177,7 @@ func generateHexID(length int) (string, error) {
|
|
|
|
|
return hex.EncodeToString(bytes), nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func parseConfig(json gjson.Result, config *AISecurityConfig, log log.Log) error {
|
|
|
|
|
func parseConfig(json gjson.Result, config *AISecurityConfig) error {
|
|
|
|
|
serviceName := json.Get("serviceName").String()
|
|
|
|
|
servicePort := json.Get("servicePort").Int()
|
|
|
|
|
serviceHost := json.Get("serviceHost").String()
|
|
|
|
@@ -235,6 +237,11 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log log.Log) error
|
|
|
|
|
} else {
|
|
|
|
|
config.timeout = DefaultTimeout
|
|
|
|
|
}
|
|
|
|
|
if obj := json.Get("bufferLimit"); obj.Exists() {
|
|
|
|
|
config.bufferLimit = int(obj.Int())
|
|
|
|
|
} else {
|
|
|
|
|
config.bufferLimit = 1000
|
|
|
|
|
}
|
|
|
|
|
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
|
|
|
FQDN: serviceName,
|
|
|
|
|
Port: servicePort,
|
|
|
|
@@ -253,7 +260,7 @@ func generateRandomID() string {
|
|
|
|
|
return "chatcmpl-" + string(b)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log log.Log) types.Action {
|
|
|
|
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action {
|
|
|
|
|
ctx.DisableReroute()
|
|
|
|
|
if !config.checkRequest {
|
|
|
|
|
log.Debugf("request checking is disabled")
|
|
|
|
@@ -262,7 +269,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
|
|
|
|
return types.ActionContinue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log log.Log) types.Action {
|
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
|
|
|
|
|
log.Debugf("checking request body...")
|
|
|
|
|
startTime := time.Now().UnixMilli()
|
|
|
|
|
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
|
|
|
@@ -305,7 +312,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|
|
|
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
|
|
|
|
denyMessage = response.Data.Advice[0].Answer
|
|
|
|
|
}
|
|
|
|
|
marshalledDenyMessage := marshalStr(denyMessage, log)
|
|
|
|
|
marshalledDenyMessage := marshalStr(denyMessage)
|
|
|
|
|
if config.protocolOriginal {
|
|
|
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
|
|
|
|
} else if gjson.GetBytes(body, "stream").Bool() {
|
|
|
|
@@ -350,7 +357,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|
|
|
|
"AccessKeyId": config.ak,
|
|
|
|
|
"Timestamp": timestamp,
|
|
|
|
|
"Service": config.requestCheckService,
|
|
|
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece, log)),
|
|
|
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)),
|
|
|
|
|
}
|
|
|
|
|
if config.token != "" {
|
|
|
|
|
params["SecurityToken"] = config.token
|
|
|
|
@@ -371,7 +378,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|
|
|
|
return types.ActionPause
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log log.Log) types.Action {
|
|
|
|
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action {
|
|
|
|
|
if !config.checkResponse {
|
|
|
|
|
log.Debugf("response checking is disabled")
|
|
|
|
|
ctx.DontReadResponseBody()
|
|
|
|
@@ -383,10 +390,126 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
|
|
|
|
ctx.DontReadResponseBody()
|
|
|
|
|
return types.ActionContinue
|
|
|
|
|
}
|
|
|
|
|
return types.HeaderStopIteration
|
|
|
|
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
|
|
|
|
ctx.SetContext("end_of_stream_received", false)
|
|
|
|
|
ctx.SetContext("during_call", false)
|
|
|
|
|
ctx.SetContext("risk_detected", false)
|
|
|
|
|
sessionID, _ := generateHexID(20)
|
|
|
|
|
ctx.SetContext("sessionID", sessionID)
|
|
|
|
|
if strings.Contains(contentType, "text/event-stream") {
|
|
|
|
|
ctx.NeedPauseStreamingResponse()
|
|
|
|
|
return types.ActionContinue
|
|
|
|
|
} else {
|
|
|
|
|
ctx.BufferResponseBody()
|
|
|
|
|
return types.HeaderStopIteration
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log log.Log) types.Action {
|
|
|
|
|
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, data []byte, endOfStream bool) []byte {
|
|
|
|
|
var bufferQueue [][]byte
|
|
|
|
|
var singleCall func()
|
|
|
|
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
|
|
|
log.Info(string(responseBody))
|
|
|
|
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
|
|
|
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
|
|
|
|
proxywasm.ResumeHttpResponse()
|
|
|
|
|
}
|
|
|
|
|
ctx.SetContext("during_call", false)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
var response Response
|
|
|
|
|
err := json.Unmarshal(responseBody, &response)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Error("failed to unmarshal aliyun content security response at response phase")
|
|
|
|
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
|
|
|
|
proxywasm.ResumeHttpResponse()
|
|
|
|
|
}
|
|
|
|
|
ctx.SetContext("during_call", false)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if riskLevelToInt(response.Data.RiskLevel) >= riskLevelToInt(config.riskLevelBar) {
|
|
|
|
|
denyMessage := DefaultDenyMessage
|
|
|
|
|
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
|
|
|
|
denyMessage = "\n" + response.Data.Advice[0].Answer
|
|
|
|
|
} else if config.denyMessage != "" {
|
|
|
|
|
denyMessage = config.denyMessage
|
|
|
|
|
}
|
|
|
|
|
marshalledDenyMessage := marshalStr(denyMessage)
|
|
|
|
|
randomID := generateRandomID()
|
|
|
|
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
|
|
|
|
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
|
|
|
|
|
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
|
|
|
|
|
bufferQueue = [][]byte{}
|
|
|
|
|
if !endStream {
|
|
|
|
|
ctx.SetContext("during_call", false)
|
|
|
|
|
singleCall()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
singleCall = func() {
|
|
|
|
|
if ctx.GetContext("during_call").(bool) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if ctx.BufferQueueSize() >= config.bufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
|
|
|
|
|
ctx.SetContext("during_call", true)
|
|
|
|
|
var buffer string
|
|
|
|
|
for ctx.BufferQueueSize() > 0 {
|
|
|
|
|
front := ctx.PopBuffer()
|
|
|
|
|
bufferQueue = append(bufferQueue, front)
|
|
|
|
|
msg := gjson.GetBytes(front, config.responseStreamContentJsonPath).String()
|
|
|
|
|
buffer += msg
|
|
|
|
|
if len([]rune(buffer)) >= config.bufferLimit {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
|
|
|
|
randomID, _ := generateHexID(16)
|
|
|
|
|
log.Debugf("current content piece: %s", buffer)
|
|
|
|
|
params := map[string]string{
|
|
|
|
|
"Format": "JSON",
|
|
|
|
|
"Version": "2022-03-02",
|
|
|
|
|
"SignatureMethod": "Hmac-SHA1",
|
|
|
|
|
"SignatureNonce": randomID,
|
|
|
|
|
"SignatureVersion": "1.0",
|
|
|
|
|
"Action": "TextModerationPlus",
|
|
|
|
|
"AccessKeyId": config.ak,
|
|
|
|
|
"Timestamp": timestamp,
|
|
|
|
|
"Service": config.responseCheckService,
|
|
|
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, ctx.GetContext("sessionID").(string), marshalStr(buffer)),
|
|
|
|
|
}
|
|
|
|
|
if config.token != "" {
|
|
|
|
|
params["SecurityToken"] = config.token
|
|
|
|
|
}
|
|
|
|
|
signature := getSign(params, config.sk+"&")
|
|
|
|
|
reqParams := url.Values{}
|
|
|
|
|
for k, v := range params {
|
|
|
|
|
reqParams.Add(k, v)
|
|
|
|
|
}
|
|
|
|
|
reqParams.Add("Signature", signature)
|
|
|
|
|
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("failed call the safe check service: %v", err)
|
|
|
|
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
|
|
|
|
proxywasm.ResumeHttpResponse()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if !ctx.GetContext("risk_detected").(bool) {
|
|
|
|
|
ctx.PushBuffer(data)
|
|
|
|
|
ctx.SetContext("end_of_stream_received", endOfStream)
|
|
|
|
|
if !ctx.GetContext("during_call").(bool) {
|
|
|
|
|
singleCall()
|
|
|
|
|
}
|
|
|
|
|
} else if endOfStream {
|
|
|
|
|
proxywasm.ResumeHttpResponse()
|
|
|
|
|
}
|
|
|
|
|
return []byte{}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
|
|
|
|
|
log.Debugf("checking response body...")
|
|
|
|
|
startTime := time.Now().UnixMilli()
|
|
|
|
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
|
|
|
@@ -436,7 +559,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|
|
|
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
|
|
|
|
denyMessage = response.Data.Advice[0].Answer
|
|
|
|
|
}
|
|
|
|
|
marshalledDenyMessage := marshalStr(denyMessage, log)
|
|
|
|
|
marshalledDenyMessage := marshalStr(denyMessage)
|
|
|
|
|
if config.protocolOriginal {
|
|
|
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
|
|
|
|
} else if isStreamingResponse {
|
|
|
|
@@ -480,7 +603,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|
|
|
|
"AccessKeyId": config.ak,
|
|
|
|
|
"Timestamp": timestamp,
|
|
|
|
|
"Service": config.responseCheckService,
|
|
|
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece, log)),
|
|
|
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece)),
|
|
|
|
|
}
|
|
|
|
|
if config.token != "" {
|
|
|
|
|
params["SecurityToken"] = config.token
|
|
|
|
@@ -511,7 +634,7 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
|
|
|
|
return strings.Join(strChunks, "")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func marshalStr(raw string, log log.Log) string {
|
|
|
|
|
func marshalStr(raw string) string {
|
|
|
|
|
helper := map[string]string{
|
|
|
|
|
"placeholder": raw,
|
|
|
|
|
}
|
|
|
|
|