diff --git a/plugins/wasm-go/extensions/ai-security-guard/.gitignore b/plugins/wasm-go/extensions/ai-security-guard/.gitignore new file mode 100644 index 000000000..adfb1c0b9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/.gitignore @@ -0,0 +1,4 @@ +main.wasm +v1/ +v2/ +config.yaml \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md new file mode 100644 index 000000000..a7339b1e4 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -0,0 +1,22 @@ +# 简介 + +# 配置说明 +| Name | Type | Requirement | Default | Description | +| :-: | :-: | :-: | :-: | :-: | +| serviceSource | string | requried | - | 服务来源,填dns | +| serviceName | string | requried | - | 服务名 | +| servicePort | string | requried | - | 服务端口 | +| domain | string | requried | - | 阿里云内容安全endpoint | +| ak | string | requried | - | 阿里云AK | +| sk | string | requried | - | 阿里云SK | + + +# 配置示例 +```yaml +serviceSource: "dns" +serviceName: "safecheck" +servicePort: 443 +domain: "green-cip.cn-shanghai.aliyuncs.com" +ak: "XXXXXXXXX" +sk: "XXXXXXXXXXXXXXX" +``` \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.mod b/plugins/wasm-go/extensions/ai-security-guard/go.mod new file mode 100644 index 000000000..bdc7ca8bd --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/go.mod @@ -0,0 +1,18 @@ +module myplugin + +go 1.18 + +require ( + github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc + github.com/tidwall/gjson v1.14.3 +) + +require ( + github.com/google/uuid v1.3.0 // indirect + github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect + github.com/magefile/mage v1.14.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/resp v0.1.1 // indirect +) diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.sum b/plugins/wasm-go/extensions/ai-security-guard/go.sum new file mode 100644 index 000000000..70cc690d9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/go.sum @@ -0,0 +1,24 @@ +github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY= +github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc= +github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE= +github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= +github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go new file mode 100644 index 000000000..cbe4d32a3 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -0,0 +1,272 @@ +package main + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" +) + +func main() { + wrapper.SetCtx( + "ai-security-guard", + wrapper.ParseConfigBy(parseConfig), + wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), + wrapper.ProcessRequestBodyBy(onHttpRequestBody), + wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), + wrapper.ProcessResponseBodyBy(onHttpResponseBody), + ) +} + +type AISecurityConfig struct { + client wrapper.HttpClient + ak string + sk string +} + +type StandardResponse struct { + Code int `json:"Code"` + Phase string `json:"BlockPhase"` + Message string `json:"Message"` +} + +func urlEncoding(rawStr string) string { + encodedStr := url.PathEscape(rawStr) + encodedStr = strings.ReplaceAll(encodedStr, "+", "%20") + encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A") + encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D") + encodedStr = strings.ReplaceAll(encodedStr, "&", "%26") + return encodedStr +} + +func hmacSha1(message, secret string) string { + key := []byte(secret) + h := hmac.New(sha1.New, key) + h.Write([]byte(message)) + hash := h.Sum(nil) + return base64.StdEncoding.EncodeToString(hash) +} + +func getSign(params map[string]string, secret string) string { + paramArray := []string{} + for k, v := range params { + paramArray = append(paramArray, urlEncoding(k)+"="+urlEncoding(v)) + } + sort.Slice(paramArray, func(i, j int) bool { + return paramArray[i] <= paramArray[j] + }) + canonicalStr := strings.Join(paramArray, "&") + signStr := "POST&%2F&" + urlEncoding(canonicalStr) + fmt.Println(signStr) + return hmacSha1(signStr, secret) +} + +func generateHexID(length int) (string, error) { + bytes := make([]byte, length/2) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} + +func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error { + serviceName := json.Get("serviceName").String() + servicePort := json.Get("servicePort").Int() + domain := json.Get("domain").String() + config.ak = json.Get("ak").String() + config.sk = json.Get("sk").String() + if serviceName == "" || servicePort == 0 || domain == "" { + return errors.New("invalid service config") + } + config.client = wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: serviceName, + Port: servicePort, + Domain: domain, + }) + return nil +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { + return types.ActionContinue +} + +func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { + messages := gjson.GetBytes(body, "messages").Array() + if len(messages) > 0 { + role := messages[len(messages)-1].Get("role").String() + content := messages[len(messages)-1].Get("content").String() + if role != "user" { + return types.ActionContinue + } + timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") + randomID, _ := generateHexID(16) + params := map[string]string{ + "Format": "JSON", + "Version": "2022-03-02", + "SignatureMethod": "Hmac-SHA1", + "SignatureNonce": randomID, + "SignatureVersion": "1.0", + "Action": "TextModerationPlus", + "AccessKeyId": config.ak, + "Timestamp": timestamp, + "Service": "llm_query_moderation", + "ServiceParameters": `{"content": "` + content + `"}`, + } + signature := getSign(params, config.sk+"&") + reqParams := url.Values{} + for k, v := range params { + reqParams.Add(k, v) + } + reqParams.Add("Signature", signature) + config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + respData := gjson.GetBytes(responseBody, "Data") + if respData.Exists() { + respAdvice := respData.Get("Advice") + respResult := respData.Get("Result") + if respAdvice.Exists() { + sr := StandardResponse{ + Code: 403, + Phase: "Request", + Message: respAdvice.Array()[0].Get("Answer").String(), + } + jsonData, _ := json.MarshalIndent(sr, "", " ") + proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } else if respResult.Array()[0].Get("Label").String() != "nonLabel" { + sr := StandardResponse{ + Code: 403, + Phase: "Request", + Message: "risk detected", + } + jsonData, _ := json.MarshalIndent(sr, "", " ") + proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } else { + proxywasm.ResumeHttpRequest() + } + } else { + proxywasm.ResumeHttpRequest() + } + }, + ) + return types.ActionPause + } else { + return types.ActionContinue + } +} + +func convertHeaders(hs [][2]string) map[string][]string { + ret := make(map[string][]string) + for _, h := range hs { + k, v := strings.ToLower(h[0]), h[1] + ret[k] = append(ret[k], v) + } + return ret +} + +// headers: map[string][]string -> [][2]string +func reconvertHeaders(hs map[string][]string) [][2]string { + var ret [][2]string + for k, vs := range hs { + for _, v := range vs { + ret = append(ret, [2]string{k, v}) + } + } + sort.SliceStable(ret, func(i, j int) bool { + return ret[i][0] < ret[j][0] + }) + return ret +} + +func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { + headers, err := proxywasm.GetHttpResponseHeaders() + if err != nil { + log.Warnf("failed to get response headers: %v", err) + return types.ActionContinue + } + hdsMap := convertHeaders(headers) + ctx.SetContext("headers", hdsMap) + return types.HeaderStopIteration +} + +func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { + messages := gjson.GetBytes(body, "choices").Array() + if len(messages) > 0 { + content := messages[0].Get("message").Get("content").String() + timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") + randomID, _ := generateHexID(16) + params := map[string]string{ + "Format": "JSON", + "Version": "2022-03-02", + "SignatureMethod": "Hmac-SHA1", + "SignatureNonce": randomID, + "SignatureVersion": "1.0", + "Action": "TextModerationPlus", + "AccessKeyId": config.ak, + "Timestamp": timestamp, + "Service": "llm_response_moderation", + "ServiceParameters": `{"content": "` + content + `"}`, + } + signature := getSign(params, config.sk+"&") + reqParams := url.Values{} + for k, v := range params { + reqParams.Add(k, v) + } + reqParams.Add("Signature", signature) + config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + defer proxywasm.ResumeHttpResponse() + respData := gjson.GetBytes(responseBody, "Data") + if respData.Exists() { + respAdvice := respData.Get("Advice") + respResult := respData.Get("Result") + if respAdvice.Exists() { + sr := StandardResponse{ + Code: 403, + Phase: "Response", + Message: respAdvice.Array()[0].Get("Answer").String(), + } + jsonData, _ := json.MarshalIndent(sr, "", " ") + hdsMap := ctx.GetContext("headers").(map[string][]string) + delete(hdsMap, "content-length") + hdsMap[":status"] = []string{"403"} + proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) + proxywasm.ReplaceHttpResponseBody(jsonData) + proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + } else if respResult.Array()[0].Get("Label").String() != "nonLabel" { + sr := StandardResponse{ + Code: 403, + Phase: "Response", + Message: "risk detected", + } + jsonData, _ := json.MarshalIndent(sr, "", " ") + hdsMap := ctx.GetContext("headers").(map[string][]string) + delete(hdsMap, "content-length") + hdsMap[":status"] = []string{"403"} + proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) + proxywasm.ReplaceHttpResponseBody(jsonData) + proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + } + } + }, + ) + return types.ActionPause + } else { + return types.ActionContinue + } +}