Files
higress/plugins/wasm-go/extensions/ai-security-guard/main.go

442 lines
15 KiB
Go

package main
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
mrand "math/rand"
"net/http"
"net/url"
"sort"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// main registers the plugin under the name "ai-security-guard" and wires the
// configuration parser plus the four HTTP phase handlers defined in this file
// into the wrapper framework.
func main() {
wrapper.SetCtx(
"ai-security-guard",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
)
}
// Response templates, tracing prefix and configuration defaults used by the
// handlers in this file.
const (
// OpenAIResponseFormat is the fmt template of a non-streaming chat completion.
// Arguments, in order: id, model, content. model and content are inserted
// verbatim, so they must already be valid JSON values.
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":%s,"choices":[{"index":0,"message":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":"stop"}]}`
// OpenAIStreamResponseChunk is the fmt template of one SSE content chunk.
// Arguments, in order: id, model, content (model and content inserted verbatim).
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":%s,"choices":[{"index":0,"delta":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":null}]}`
// OpenAIStreamResponseEnd is the fmt template of the terminating SSE chunk
// (finish_reason "stop"). Arguments, in order: id, model.
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model": %s,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
// OpenAIStreamResponseFormat is a complete streamed reply: one content chunk,
// the end chunk and the "[DONE]" marker, separated by blank lines.
// Arguments, in order: id, model, content, id, model.
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
// TracingPrefix is the first path element of the properties set when a
// request or response is denied.
TracingPrefix = "trace_span_tag."
// Defaults applied by parseConfig when the corresponding key is absent.
DefaultRequestCheckService = "llm_query_moderation"
DefaultResponseCheckService = "llm_response_moderation"
DefaultRequestJsonPath = "messages.@reverse.0.content"
DefaultResponseJsonPath = "choices.0.message.content"
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
DefaultDenyCode = 200
// DefaultDenyMessage is the fallback refusal text
// (Chinese for "Sorry, I cannot answer your question").
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
// AliyunUserAgent is sent as the User-Agent header on moderation calls.
AliyunUserAgent = "CIPFrom/AIGateway"
)
// AISecurityConfig holds the parsed plugin configuration together with the
// HTTP client and metric counters used by the request/response handlers.
type AISecurityConfig struct {
// client issues the moderation calls; built in parseConfig from
// serviceName, servicePort and serviceHost.
client wrapper.HttpClient
// ak is the "accessKey" setting, sent as the AccessKeyId parameter.
ak string
// sk is the "secretKey" setting, used to sign moderation calls.
sk string
// checkRequest enables moderation of the request body.
checkRequest bool
// requestCheckService is sent as the Service parameter for request checks.
requestCheckService string
// requestContentJsonPath is the gjson path of the text to check in the request body.
requestContentJsonPath string
// checkResponse enables moderation of the response body.
checkResponse bool
// responseCheckService is sent as the Service parameter for response checks.
responseCheckService string
// responseContentJsonPath is the gjson path of the text in a non-streaming response.
responseContentJsonPath string
// responseStreamContentJsonPath is the gjson path applied to each chunk of
// an event-stream response.
responseStreamContentJsonPath string
// denyCode is the HTTP status used for deny responses.
denyCode int64
// denyMessage is the optional configured refusal text.
denyMessage string
// protocolOriginal is true when the "protocol" setting equals "original";
// otherwise deny responses are rendered with the OpenAI templates.
protocolOriginal bool
// metrics caches counters by metric name; see incrementCounter.
metrics map[string]proxywasm.MetricCounter
}
// incrementCounter adds inc to the counter called metricName. The counter is
// defined on first use and cached in config.metrics for later calls.
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
	if cached, found := config.metrics[metricName]; found {
		cached.Increment(inc)
		return
	}
	created := proxywasm.DefineCounterMetric(metricName)
	config.metrics[metricName] = created
	created.Increment(inc)
}
// urlEncoding percent-encodes rawStr for the canonical string that getSign
// signs.
//
// Every byte outside the unreserved set (A-Z, a-z, 0-9, '-', '_', '.', '~')
// is written as %XX with upper-case hex digits, and a space becomes "%20".
//
// The previous implementation was based on url.PathEscape, which leaves
// '+', '$' and '@' untouched, and it then rewrote a literal '+' to "%20".
// That produced a wrong signature for any text containing those characters.
func urlEncoding(rawStr string) string {
	// url.QueryEscape already escapes everything except the unreserved set;
	// its only deviation from the required form is encoding a space as '+'.
	// A literal '+' has become "%2B" by then, so the replacement is unambiguous.
	return strings.ReplaceAll(url.QueryEscape(rawStr), "+", "%20")
}
// hmacSha1 returns the standard-base64 encoding of the HMAC-SHA1 digest of
// message keyed with secret.
func hmacSha1(message, secret string) string {
	mac := hmac.New(sha1.New, []byte(secret))
	mac.Write([]byte(message))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}
// getSign computes the request signature for params.
//
// Each key and value is percent-encoded with urlEncoding and joined as
// "key=value". The pairs are sorted in ascending byte order and joined with
// "&". The result is percent-encoded once more, prefixed with "POST&%2F&",
// and signed with hmacSha1 using secret. Callers pass the secret key with a
// trailing "&" already appended.
func getSign(params map[string]string, secret string) string {
	pairs := make([]string, 0, len(params))
	for k, v := range params {
		pairs = append(pairs, urlEncoding(k)+"="+urlEncoding(v))
	}
	// sort.Strings uses a strict "<" ordering; the previous "<=" comparator
	// did not satisfy the contract required of a sort less function.
	sort.Strings(pairs)
	canonical := strings.Join(pairs, "&")
	return hmacSha1("POST&%2F&"+urlEncoding(canonical), secret)
}
// generateHexID returns a random lower-case hexadecimal string of exactly
// length characters, drawn from crypto/rand.
//
// A length of 0 yields the empty string. A negative length is rejected with
// an error instead of panicking inside make. Odd lengths are honoured by
// generating one extra nibble and trimming it; previously they came back one
// character short.
func generateHexID(length int) (string, error) {
	if length < 0 {
		return "", fmt.Errorf("invalid hex ID length %d", length)
	}
	// One random byte encodes to two hex digits, so round the byte count up.
	buf := make([]byte, (length+1)/2)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return hex.EncodeToString(buf)[:length], nil
}
// parseConfig fills config from the plugin's JSON configuration.
//
// serviceName, servicePort and serviceHost are mandatory, as are accessKey
// and secretKey; a missing value yields an error. Optional keys fall back to
// the Default* constants when they are absent from the document. On success
// the moderation HTTP client and an empty metrics cache are installed.
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
	name := json.Get("serviceName").String()
	port := json.Get("servicePort").Int()
	host := json.Get("serviceHost").String()
	if name == "" || port == 0 || host == "" {
		return errors.New("invalid service config")
	}

	config.ak = json.Get("accessKey").String()
	config.sk = json.Get("secretKey").String()
	if config.ak == "" || config.sk == "" {
		return errors.New("invalid AK/SK config")
	}

	config.checkRequest = json.Get("checkRequest").Bool()
	config.checkResponse = json.Get("checkResponse").Bool()
	config.protocolOriginal = json.Get("protocol").String() == "original"
	config.denyMessage = json.Get("denyMessage").String()

	config.denyCode = DefaultDenyCode
	if code := json.Get("denyCode"); code.Exists() {
		config.denyCode = code.Int()
	}

	// textOr yields the string at key when the key is present (even if it is
	// empty) and fallback only when the key is absent.
	textOr := func(key, fallback string) string {
		if field := json.Get(key); field.Exists() {
			return field.String()
		}
		return fallback
	}
	config.requestCheckService = textOr("requestCheckService", DefaultRequestCheckService)
	config.responseCheckService = textOr("responseCheckService", DefaultResponseCheckService)
	config.requestContentJsonPath = textOr("requestContentJsonPath", DefaultRequestJsonPath)
	config.responseContentJsonPath = textOr("responseContentJsonPath", DefaultResponseJsonPath)
	config.responseStreamContentJsonPath = textOr("responseStreamContentJsonPath", DefaultStreamingResponseJsonPath)

	cluster := wrapper.FQDNCluster{
		FQDN: name,
		Port: port,
		Host: host,
	}
	config.client = wrapper.NewClusterClient(cluster)
	config.metrics = make(map[string]proxywasm.MetricCounter)
	return nil
}
// generateRandomID builds a completion identifier: the prefix "chatcmpl-"
// followed by 29 characters picked from [a-zA-Z0-9] with math/rand.
func generateRandomID() string {
	const (
		prefix    = "chatcmpl-"
		alphabet  = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		suffixLen = 29
	)
	id := make([]byte, 0, len(prefix)+suffixLen)
	id = append(id, prefix...)
	for n := 0; n < suffixLen; n++ {
		id = append(id, alphabet[mrand.Intn(len(alphabet))])
	}
	return string(id)
}
// onHttpRequestHeaders runs in the request-header phase. It tells the
// framework not to read the request body when request checking is disabled,
// and not to read the response body when response checking is disabled.
// It always returns types.ActionContinue.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkRequest {
log.Debugf("request checking is disabled")
ctx.DontReadRequestBody()
}
if !config.checkResponse {
log.Debugf("response checking is disabled")
ctx.DontReadResponseBody()
}
return types.ActionContinue
}
// onHttpRequestBody moderates the request body.
//
// The text found at config.requestContentJsonPath is sent to the moderation
// service (action TextModerationPlus) in a signed call. The request is paused
// until the callback runs. If the first result label is anything other than
// "nonLabel", a deny response is sent to the client; otherwise the request is
// resumed.
//
// Failures are handled fail-open, matching the existing non-200 handling: a
// nonce or dispatch error, a non-200 reply, or a reply without a result label
// lets the request through.
//
// The raw JSON "model" value is stored in the context under "requestModel"
// for the response phase; when the request has no model, the JSON string
// "unknown" is stored so the response templates stay valid JSON.
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
	log.Debugf("checking request body...")
	content := gjson.GetBytes(body, config.requestContentJsonPath).String()
	// model stays in raw JSON form (quotes included) because the response
	// templates insert it verbatim.
	model := gjson.GetBytes(body, "model").Raw
	if model == "" {
		model = `"unknown"`
	}
	ctx.SetContext("requestModel", model)
	if content == "" {
		log.Debugf("request content is empty. skip")
		return types.ActionContinue
	}

	nonce, err := generateHexID(16)
	if err != nil {
		log.Errorf("failed to generate signature nonce: %v", err)
		return types.ActionContinue
	}
	// The content is untrusted user input: serialize it as a JSON string so
	// quotes, backslashes and control characters cannot break or alter the
	// ServiceParameters document.
	quotedContent, err := json.Marshal(content)
	if err != nil {
		log.Errorf("failed to serialize request content: %v", err)
		return types.ActionContinue
	}
	params := map[string]string{
		"Format":            "JSON",
		"Version":           "2022-03-02",
		"SignatureMethod":   "Hmac-SHA1",
		"SignatureNonce":    nonce,
		"SignatureVersion":  "1.0",
		"Action":            "TextModerationPlus",
		"AccessKeyId":       config.ak,
		"Timestamp":         time.Now().UTC().Format("2006-01-02T15:04:05Z"),
		"Service":           config.requestCheckService,
		"ServiceParameters": `{"content": ` + string(quotedContent) + `}`,
	}
	signature := getSign(params, config.sk+"&")
	reqParams := url.Values{}
	for k, v := range params {
		reqParams.Add(k, v)
	}
	reqParams.Add("Signature", signature)

	// quote renders s as a JSON string literal, falling back to the quoted
	// default message if serialization fails.
	quote := func(s string) string {
		if data, err := json.Marshal(s); err == nil {
			return string(data)
		}
		return fmt.Sprintf("\"%s\"", DefaultDenyMessage)
	}

	err = config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if statusCode != 200 {
				log.Error(string(responseBody))
				proxywasm.ResumeHttpRequest()
				return
			}
			respData := gjson.GetBytes(responseBody, "Data")
			// Path lookups return a non-existent result instead of panicking
			// when Data, Result or Advice is missing or empty.
			label := respData.Get("Result.0.Label")
			if !label.Exists() || label.String() == "nonLabel" {
				proxywasm.ResumeHttpRequest()
				return
			}

			// denyMessage is always a JSON value ready to be embedded verbatim.
			// The "original" protocol prefers the configured message; the
			// OpenAI protocol prefers the answer suggested by the service.
			answer := respData.Get("Advice.0.Answer")
			var denyMessage string
			switch {
			case config.protocolOriginal && config.denyMessage != "":
				denyMessage = quote(config.denyMessage)
			case answer.Exists():
				denyMessage = answer.Raw
			case config.denyMessage != "":
				denyMessage = quote(config.denyMessage)
			default:
				denyMessage = quote(DefaultDenyMessage)
			}

			proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(label.String()))
			proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
			config.incrementCounter("ai_sec_request_deny", 1)
			switch {
			case config.protocolOriginal:
				proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyMessage), -1)
			case gjson.GetBytes(body, "stream").Bool():
				randomID := generateRandomID()
				jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
				proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
			default:
				randomID := generateRandomID()
				jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
				proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
			}
			// The deny response is generated locally; it must not be moderated again.
			ctx.DontReadResponseBody()
		},
	)
	if err != nil {
		log.Errorf("failed call the safe check service: %v", err)
		return types.ActionContinue
	}
	return types.ActionPause
}
// convertHeaders groups a list of header pairs by lower-cased header name.
// Values keep their input order under each name; the returned map is never nil.
func convertHeaders(hs [][2]string) map[string][]string {
	grouped := map[string][]string{}
	for i := range hs {
		name := strings.ToLower(hs[i][0])
		grouped[name] = append(grouped[name], hs[i][1])
	}
	return grouped
}
// reconvertHeaders flattens a header map back into a list of pairs
// (map[string][]string -> [][2]string). Pairs are ordered by header name, and
// the values of one header keep their order. An empty map yields a nil slice.
func reconvertHeaders(hs map[string][]string) [][2]string {
	names := make([]string, 0, len(hs))
	for name := range hs {
		names = append(names, name)
	}
	// Map keys are unique, so ordering the names first gives the same result
	// as a stable sort of the flattened pairs by name.
	sort.Strings(names)

	var pairs [][2]string
	for _, name := range names {
		for _, value := range hs[name] {
			pairs = append(pairs, [2]string{name, value})
		}
	}
	return pairs
}
// onHttpResponseHeaders runs in the response-header phase.
//
// When response checking is enabled it stores the response headers, grouped
// by lower-cased name, in the context under "headers" and returns
// types.HeaderStopIteration. The headers are held back so the body handler
// can still replace them if the response is denied.
//
// When response checking is disabled, or the headers cannot be read, nothing
// will rewrite the response. The body is marked as not needed and the headers
// are released immediately with types.ActionContinue.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
	if !config.checkResponse {
		log.Debugf("response checking is disabled")
		ctx.DontReadResponseBody()
		return types.ActionContinue
	}
	headers, err := proxywasm.GetHttpResponseHeaders()
	if err != nil {
		log.Warnf("failed to get response headers: %v", err)
		// Headers are about to be released, so the body handler could no
		// longer rewrite them; skip the body check altogether.
		ctx.DontReadResponseBody()
		return types.ActionContinue
	}
	ctx.SetContext("headers", convertHeaders(headers))
	return types.HeaderStopIteration
}
// onHttpResponseBody moderates the response body.
//
// The text is extracted from the body: per chunk for an event-stream
// response, otherwise via config.responseContentJsonPath. It is sent to the
// moderation service in a signed call, and the response is paused until the
// callback runs. If the first result label is anything other than "nonLabel",
// the response headers and body are replaced with a deny response.
//
// Failures are handled fail-open: if the headers were not captured, the nonce
// or dispatch fails, the service replies with a non-200 status, or the reply
// has no result label, the original response is passed through.
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
	log.Debugf("checking response body...")
	// The header phase stores the map only on success; a checked assertion
	// avoids a panic when it is absent.
	hdsMap, ok := ctx.GetContext("headers").(map[string][]string)
	if !ok {
		log.Warnf("response headers were not captured, skip response check")
		return types.ActionContinue
	}
	isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
	// model is inserted verbatim into the JSON templates, so the fallback has
	// to be a quoted JSON string as well.
	model := ctx.GetStringContext("requestModel", "")
	if model == "" {
		model = `"unknown"`
	}
	var content string
	if isStreamingResponse {
		content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
	} else {
		content = gjson.GetBytes(body, config.responseContentJsonPath).String()
	}
	log.Debugf("Raw response content is: %s", content)
	if len(content) == 0 {
		log.Debugf("response content is empty. skip")
		return types.ActionContinue
	}

	nonce, err := generateHexID(16)
	if err != nil {
		log.Errorf("failed to generate signature nonce: %v", err)
		return types.ActionContinue
	}
	// Serialize the model output as a JSON string so quotes, backslashes and
	// newlines cannot break the ServiceParameters document.
	quotedContent, err := json.Marshal(content)
	if err != nil {
		log.Errorf("failed to serialize response content: %v", err)
		return types.ActionContinue
	}
	params := map[string]string{
		"Format":            "JSON",
		"Version":           "2022-03-02",
		"SignatureMethod":   "Hmac-SHA1",
		"SignatureNonce":    nonce,
		"SignatureVersion":  "1.0",
		"Action":            "TextModerationPlus",
		"AccessKeyId":       config.ak,
		"Timestamp":         time.Now().UTC().Format("2006-01-02T15:04:05Z"),
		"Service":           config.responseCheckService,
		"ServiceParameters": `{"content": ` + string(quotedContent) + `}`,
	}
	signature := getSign(params, config.sk+"&")
	reqParams := url.Values{}
	for k, v := range params {
		reqParams.Add(k, v)
	}
	reqParams.Add("Signature", signature)

	// quote renders s as a JSON string literal, falling back to the quoted
	// default message if serialization fails.
	quote := func(s string) string {
		if data, err := json.Marshal(s); err == nil {
			return string(data)
		}
		return fmt.Sprintf("\"%s\"", DefaultDenyMessage)
	}

	err = config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			// Whatever happens below, the paused response must be released.
			defer proxywasm.ResumeHttpResponse()
			if statusCode != 200 {
				log.Error(string(responseBody))
				return
			}
			respData := gjson.GetBytes(responseBody, "Data")
			// Path lookups return a non-existent result instead of panicking
			// when Data, Result or Advice is missing or empty.
			label := respData.Get("Result.0.Label")
			if !label.Exists() || label.String() == "nonLabel" {
				return
			}
			answer := respData.Get("Advice.0.Answer")

			var jsonData []byte
			if config.protocolOriginal {
				// The original protocol body is kept in its existing form:
				// configured text as-is, else the raw answer, else the default.
				switch {
				case config.denyMessage != "":
					jsonData = []byte(config.denyMessage)
				case answer.Exists():
					jsonData = []byte(answer.Raw)
				default:
					jsonData = []byte(DefaultDenyMessage)
				}
			} else {
				// The OpenAI templates embed the message verbatim, so it must
				// be a JSON value: the raw answer, or a quoted fallback.
				var denyMessage string
				switch {
				case answer.Exists():
					denyMessage = answer.Raw
				case config.denyMessage != "":
					denyMessage = quote(config.denyMessage)
				default:
					denyMessage = quote(DefaultDenyMessage)
				}
				randomID := generateRandomID()
				if isStreamingResponse {
					jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
				} else {
					jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
				}
			}

			// The body length changes, so the upstream content-length is dropped.
			delete(hdsMap, "content-length")
			hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
			proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
			proxywasm.ReplaceHttpResponseBody(jsonData)
			proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(label.String()))
			proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
			config.incrementCounter("ai_sec_response_deny", 1)
		},
	)
	if err != nil {
		log.Errorf("failed call the safe check service: %v", err)
		return types.ActionContinue
	}
	return types.ActionPause
}
// extractMessageFromStreamingBody concatenates the text found at jsonPath in
// every chunk of a streamed body. Chunks are the pieces of the
// whitespace-trimmed data separated by a blank line ("\n\n"); chunks without
// a value at jsonPath contribute nothing.
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
	var message strings.Builder
	for _, event := range bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) {
		// Example chunk payload:
		// "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
		if piece := gjson.GetBytes(event, jsonPath); piece.Exists() {
			message.WriteString(piece.String())
		}
	}
	return message.String()
}