Improve ai plugins (#1657)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
rinfx
2025-01-09 22:04:51 +08:00
committed by GitHub
parent 2a89c3bb70
commit ea0d5e7564
5 changed files with 31 additions and 34 deletions

View File

@@ -41,9 +41,9 @@ const (
LowRisk = "low"
NoRisk = "none"
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
DefaultRequestCheckService = "llm_query_moderation"
@@ -262,8 +262,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
log.Debugf("checking request body...")
startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").String()
ctx.SetContext("requestModel", model)
log.Debugf("Raw request content is: %s", content)
if len(content) == 0 {
log.Info("request content is empty. skip")
@@ -308,11 +306,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
@@ -369,15 +367,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
return types.ActionPause
}
func convertHeaders(hs [][2]string) map[string][]string {
ret := make(map[string][]string)
for _, h := range hs {
k, v := strings.ToLower(h[0]), h[1]
ret[k] = append(ret[k], v)
}
return ret
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkResponse {
log.Debugf("response checking is disabled")
@@ -398,7 +387,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
startTime := time.Now().UnixMilli()
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
isStreamingResponse := strings.Contains(contentType, "event-stream")
model := ctx.GetStringContext("requestModel", "unknown")
var content string
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
@@ -449,11 +437,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
config.incrementCounter("ai_sec_response_deny", 1)