Standardize the data structure returned by the AI security security a… (#1344)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
Benny
2024-09-26 11:07:44 +08:00
committed by GitHub
parent af4e34b7ed
commit 260772926c
2 changed files with 44 additions and 12 deletions

View File

@@ -30,7 +30,6 @@ description: 阿里云内容安全检测
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 |
## 配置示例
### 前提条件
由于插件中需要调用阿里云内容安全服务所以需要先创建一个DNS类型的服务例如
@@ -123,7 +122,7 @@ curl http://localhost/v1/chat/completions \
```json
{
"id": "chatcmpl-123",
"id": "chatcmpl-AAy3hK1dE4ODaegbGOMoC9VY4Sizv",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o-mini",

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
mrand "math/rand"
"net/http"
"net/url"
"sort"
@@ -33,10 +34,10 @@ func main() {
}
const (
OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":%s,"choices":[{"index":0,"message":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":%s,"choices":[{"index":0,"delta":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model": %s,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
TracingPrefix = "trace_span_tag."
@@ -168,18 +169,32 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
return nil
}
func generateRandomID() string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, 29)
for i := range b {
b[i] = charset[mrand.Intn(len(charset))]
}
return "chatcmpl-" + string(b)
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkRequest {
log.Debugf("request checking is disabled")
ctx.DontReadRequestBody()
}
if !config.checkResponse {
log.Debugf("response checking is disabled")
ctx.DontReadResponseBody()
}
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
proxywasm.LogDebugf("checking request body...")
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").Raw
ctx.SetContext("requestModel", model)
if content != "" {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
@@ -201,7 +216,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
@@ -214,13 +229,17 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
if config.denyMessage != "" {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1)
} else {
answer := respAdvice.Array()[0].Get("Answer").Raw
if gjson.GetBytes(body, "stream").Bool() {
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, answer))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, answer))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
}
} else {
proxywasm.ResumeHttpRequest()
@@ -230,8 +249,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
}
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
}
return types.ActionPause
} else {
proxywasm.LogDebugf("request content is empty. skip")
return types.ActionContinue
}
}
@@ -271,8 +295,10 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
proxywasm.LogDebugf("checking response body...")
hdsMap := ctx.GetContext("headers").(map[string][]string)
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
model := ctx.GetStringContext("requestModel", "unknown")
var content string
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
@@ -301,7 +327,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
respData := gjson.GetBytes(responseBody, "Data")
@@ -314,9 +340,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
jsonData = []byte(config.denyMessage)
} else {
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
} else {
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
}
}
delete(hdsMap, "content-length")
@@ -330,8 +358,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
}
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
}
return types.ActionPause
} else {
proxywasm.LogDebugf("request content is empty. skip")
return types.ActionContinue
}
}