diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 5a8f753f3..8b63213d5 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -30,7 +30,6 @@ description: 阿里云内容安全检测 | `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | | `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 | - ## 配置示例 ### 前提条件 由于插件中需要调用阿里云内容安全服务,所以需要先创建一个DNS类型的服务,例如: @@ -123,7 +122,7 @@ curl http://localhost/v1/chat/completions \ ```json { - "id": "chatcmpl-123", + "id": "chatcmpl-AAy3hK1dE4ODaegbGOMoC9VY4Sizv", "object": "chat.completion", "created": 1677652288, "model": "gpt-4o-mini", diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index e1ccfeb57..4640596c8 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + mrand "math/rand" "net/http" "net/url" "sort" @@ -33,10 +34,10 @@ func main() { } const ( - OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}` - OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` - OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` - OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":%s,"choices":[{"index":0,"message":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":"stop"}]}` + OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":%s,"choices":[{"index":0,"delta":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":null}]}` + OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model": %s,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` + OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` TracingPrefix = "trace_span_tag." @@ -168,18 +169,32 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e return nil } +func generateRandomID() string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, 29) + for i := range b { + b[i] = charset[mrand.Intn(len(charset))] + } + return "chatcmpl-" + string(b) +} + func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { if !config.checkRequest { + log.Debugf("request checking is disabled") ctx.DontReadRequestBody() } if !config.checkResponse { + log.Debugf("response checking is disabled") ctx.DontReadResponseBody() } return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { + proxywasm.LogDebugf("checking request body...") content := gjson.GetBytes(body, config.requestContentJsonPath).String() + model := gjson.GetBytes(body, "model").Raw + ctx.SetContext("requestModel", model) if content != "" { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) @@ -201,7 +216,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] reqParams.Add(k, v) } reqParams.Add("Signature", signature) - config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { respData := gjson.GetBytes(responseBody, "Data") if respData.Exists() { @@ -214,13 +229,17 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] if config.denyMessage != "" { proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1) } else { + answer := respAdvice.Array()[0].Get("Answer").Raw if gjson.GetBytes(body, "stream").Bool() { - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, answer)) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) } else { - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, answer)) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } + ctx.DontReadResponseBody() } } else { proxywasm.ResumeHttpRequest() @@ -230,8 +249,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } }, ) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + return types.ActionContinue + } return types.ActionPause } else { + proxywasm.LogDebugf("request content is empty. skip") return types.ActionContinue } } @@ -271,8 +295,10 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log } func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { + proxywasm.LogDebugf("checking response body...") hdsMap := ctx.GetContext("headers").(map[string][]string) isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") + model := ctx.GetStringContext("requestModel", "unknown") var content string if isStreamingResponse { content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) @@ -301,7 +327,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ reqParams.Add(k, v) } reqParams.Add("Signature", signature) - config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { defer proxywasm.ResumeHttpResponse() respData := gjson.GetBytes(responseBody, "Data") @@ -314,9 +340,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ jsonData = []byte(config.denyMessage) } else { if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { - jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String())) } else { - jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String())) } } delete(hdsMap, "content-length") @@ -330,8 +358,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ } }, ) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + return types.ActionContinue + } return types.ActionPause } else { + proxywasm.LogDebugf("request content is empty. skip") return types.ActionContinue } }