fix special charactor handle in ai-security-guard plugin (#1394)

This commit is contained in:
rinfx
2024-10-18 16:32:48 +08:00
committed by GitHub
parent 49bb5ec2b9
commit 32e5a59ae0

View File

@@ -81,7 +81,7 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)
func urlEncoding(rawStr string) string {
encodedStr := url.PathEscape(rawStr)
encodedStr = strings.ReplaceAll(encodedStr, "+", "%20")
encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
@@ -106,7 +106,7 @@ func getSign(params map[string]string, secret string) string {
})
canonicalStr := strings.Join(paramArray, "&")
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
// proxywasm.LogInfo(signStr)
proxywasm.LogDebugf("String to sign is: %s", signStr)
return hmacSha1(signStr, secret)
}
@@ -196,10 +196,11 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking request body...")
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
model := gjson.GetBytes(body, "model").Raw
ctx.SetContext("requestModel", model)
if content != "" {
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
@@ -212,7 +213,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
@@ -339,7 +340,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
} else {
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
content = gjson.GetBytes(body, config.responseContentJsonPath).Raw
}
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
@@ -355,7 +356,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
@@ -400,10 +401,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
jsonData = []byte(denyMessage)
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String(), randomID, model))
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
@@ -432,10 +433,10 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
strChunks = append(strChunks, jsonObj.String())
jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw
if len(jsonRaw) > 2 {
strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1])
}
}
return strings.Join(strChunks, "")
return fmt.Sprintf(`"%s"`, strings.Join(strChunks, ""))
}