mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
fix special charactor handle in ai-security-guard plugin (#1394)
This commit is contained in:
@@ -81,7 +81,7 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)
|
|||||||
|
|
||||||
func urlEncoding(rawStr string) string {
|
func urlEncoding(rawStr string) string {
|
||||||
encodedStr := url.PathEscape(rawStr)
|
encodedStr := url.PathEscape(rawStr)
|
||||||
encodedStr = strings.ReplaceAll(encodedStr, "+", "%20")
|
encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
|
||||||
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
|
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
|
||||||
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
|
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
|
||||||
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
|
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
|
||||||
@@ -106,7 +106,7 @@ func getSign(params map[string]string, secret string) string {
|
|||||||
})
|
})
|
||||||
canonicalStr := strings.Join(paramArray, "&")
|
canonicalStr := strings.Join(paramArray, "&")
|
||||||
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
||||||
// proxywasm.LogInfo(signStr)
|
proxywasm.LogDebugf("String to sign is: %s", signStr)
|
||||||
return hmacSha1(signStr, secret)
|
return hmacSha1(signStr, secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,10 +196,11 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
|||||||
|
|
||||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||||
log.Debugf("checking request body...")
|
log.Debugf("checking request body...")
|
||||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
|
||||||
model := gjson.GetBytes(body, "model").Raw
|
model := gjson.GetBytes(body, "model").Raw
|
||||||
ctx.SetContext("requestModel", model)
|
ctx.SetContext("requestModel", model)
|
||||||
if content != "" {
|
log.Debugf("Raw response content is: %s", content)
|
||||||
|
if len(content) > 0 {
|
||||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||||
randomID, _ := generateHexID(16)
|
randomID, _ := generateHexID(16)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
@@ -212,7 +213,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|||||||
"AccessKeyId": config.ak,
|
"AccessKeyId": config.ak,
|
||||||
"Timestamp": timestamp,
|
"Timestamp": timestamp,
|
||||||
"Service": config.requestCheckService,
|
"Service": config.requestCheckService,
|
||||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
|
||||||
}
|
}
|
||||||
signature := getSign(params, config.sk+"&")
|
signature := getSign(params, config.sk+"&")
|
||||||
reqParams := url.Values{}
|
reqParams := url.Values{}
|
||||||
@@ -339,7 +340,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
if isStreamingResponse {
|
if isStreamingResponse {
|
||||||
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
||||||
} else {
|
} else {
|
||||||
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
|
content = gjson.GetBytes(body, config.responseContentJsonPath).Raw
|
||||||
}
|
}
|
||||||
log.Debugf("Raw response content is: %s", content)
|
log.Debugf("Raw response content is: %s", content)
|
||||||
if len(content) > 0 {
|
if len(content) > 0 {
|
||||||
@@ -355,7 +356,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
"AccessKeyId": config.ak,
|
"AccessKeyId": config.ak,
|
||||||
"Timestamp": timestamp,
|
"Timestamp": timestamp,
|
||||||
"Service": config.responseCheckService,
|
"Service": config.responseCheckService,
|
||||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
|
||||||
}
|
}
|
||||||
signature := getSign(params, config.sk+"&")
|
signature := getSign(params, config.sk+"&")
|
||||||
reqParams := url.Values{}
|
reqParams := url.Values{}
|
||||||
@@ -400,10 +401,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
jsonData = []byte(denyMessage)
|
jsonData = []byte(denyMessage)
|
||||||
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
|
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
|
||||||
randomID := generateRandomID()
|
randomID := generateRandomID()
|
||||||
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String(), randomID, model))
|
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
|
||||||
} else {
|
} else {
|
||||||
randomID := generateRandomID()
|
randomID := generateRandomID()
|
||||||
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
|
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
|
||||||
}
|
}
|
||||||
delete(hdsMap, "content-length")
|
delete(hdsMap, "content-length")
|
||||||
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
||||||
@@ -432,10 +433,10 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
|||||||
strChunks := []string{}
|
strChunks := []string{}
|
||||||
for _, chunk := range chunks {
|
for _, chunk := range chunks {
|
||||||
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw
|
||||||
if jsonObj.Exists() {
|
if len(jsonRaw) > 2 {
|
||||||
strChunks = append(strChunks, jsonObj.String())
|
strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(strChunks, "")
|
return fmt.Sprintf(`"%s"`, strings.Join(strChunks, ""))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user