mirror of
https://github.com/alibaba/higress.git
synced 2026-05-27 22:27:29 +08:00
add plugin: ai-security-guard (#1034)
This commit is contained in:
4
plugins/wasm-go/extensions/ai-security-guard/.gitignore
vendored
Normal file
4
plugins/wasm-go/extensions/ai-security-guard/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
main.wasm
|
||||||
|
v1/
|
||||||
|
v2/
|
||||||
|
config.yaml
|
||||||
22
plugins/wasm-go/extensions/ai-security-guard/README.md
Normal file
22
plugins/wasm-go/extensions/ai-security-guard/README.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# 简介
|
||||||
|
|
||||||
|
# 配置说明
|
||||||
|
| Name | Type | Requirement | Default | Description |
|
||||||
|
| :-: | :-: | :-: | :-: | :-: |
|
||||||
|
| serviceSource | string | requried | - | 服务来源,填dns |
|
||||||
|
| serviceName | string | requried | - | 服务名 |
|
||||||
|
| servicePort | string | requried | - | 服务端口 |
|
||||||
|
| domain | string | requried | - | 阿里云内容安全endpoint |
|
||||||
|
| ak | string | requried | - | 阿里云AK |
|
||||||
|
| sk | string | requried | - | 阿里云SK |
|
||||||
|
|
||||||
|
|
||||||
|
# 配置示例
|
||||||
|
```yaml
|
||||||
|
serviceSource: "dns"
|
||||||
|
serviceName: "safecheck"
|
||||||
|
servicePort: 443
|
||||||
|
domain: "green-cip.cn-shanghai.aliyuncs.com"
|
||||||
|
ak: "XXXXXXXXX"
|
||||||
|
sk: "XXXXXXXXXXXXXXX"
|
||||||
|
```
|
||||||
18
plugins/wasm-go/extensions/ai-security-guard/go.mod
Normal file
18
plugins/wasm-go/extensions/ai-security-guard/go.mod
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
module myplugin
|
||||||
|
|
||||||
|
go 1.18
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||||
|
github.com/tidwall/gjson v1.14.3
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||||
|
github.com/magefile/mage v1.14.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
github.com/tidwall/resp v0.1.1 // indirect
|
||||||
|
)
|
||||||
24
plugins/wasm-go/extensions/ai-security-guard/go.sum
Normal file
24
plugins/wasm-go/extensions/ai-security-guard/go.sum
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
|
||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||||
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||||
|
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||||
|
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||||
|
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||||
|
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
272
plugins/wasm-go/extensions/ai-security-guard/main.go
Normal file
272
plugins/wasm-go/extensions/ai-security-guard/main.go
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
wrapper.SetCtx(
|
||||||
|
"ai-security-guard",
|
||||||
|
wrapper.ParseConfigBy(parseConfig),
|
||||||
|
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||||
|
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||||
|
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||||
|
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AISecurityConfig struct {
|
||||||
|
client wrapper.HttpClient
|
||||||
|
ak string
|
||||||
|
sk string
|
||||||
|
}
|
||||||
|
|
||||||
|
type StandardResponse struct {
|
||||||
|
Code int `json:"Code"`
|
||||||
|
Phase string `json:"BlockPhase"`
|
||||||
|
Message string `json:"Message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func urlEncoding(rawStr string) string {
|
||||||
|
encodedStr := url.PathEscape(rawStr)
|
||||||
|
encodedStr = strings.ReplaceAll(encodedStr, "+", "%20")
|
||||||
|
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
|
||||||
|
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
|
||||||
|
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
|
||||||
|
return encodedStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func hmacSha1(message, secret string) string {
|
||||||
|
key := []byte(secret)
|
||||||
|
h := hmac.New(sha1.New, key)
|
||||||
|
h.Write([]byte(message))
|
||||||
|
hash := h.Sum(nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSign(params map[string]string, secret string) string {
|
||||||
|
paramArray := []string{}
|
||||||
|
for k, v := range params {
|
||||||
|
paramArray = append(paramArray, urlEncoding(k)+"="+urlEncoding(v))
|
||||||
|
}
|
||||||
|
sort.Slice(paramArray, func(i, j int) bool {
|
||||||
|
return paramArray[i] <= paramArray[j]
|
||||||
|
})
|
||||||
|
canonicalStr := strings.Join(paramArray, "&")
|
||||||
|
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
||||||
|
fmt.Println(signStr)
|
||||||
|
return hmacSha1(signStr, secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateHexID(length int) (string, error) {
|
||||||
|
bytes := make([]byte, length/2)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
|
||||||
|
serviceName := json.Get("serviceName").String()
|
||||||
|
servicePort := json.Get("servicePort").Int()
|
||||||
|
domain := json.Get("domain").String()
|
||||||
|
config.ak = json.Get("ak").String()
|
||||||
|
config.sk = json.Get("sk").String()
|
||||||
|
if serviceName == "" || servicePort == 0 || domain == "" {
|
||||||
|
return errors.New("invalid service config")
|
||||||
|
}
|
||||||
|
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||||
|
ServiceName: serviceName,
|
||||||
|
Port: servicePort,
|
||||||
|
Domain: domain,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||||
|
messages := gjson.GetBytes(body, "messages").Array()
|
||||||
|
if len(messages) > 0 {
|
||||||
|
role := messages[len(messages)-1].Get("role").String()
|
||||||
|
content := messages[len(messages)-1].Get("content").String()
|
||||||
|
if role != "user" {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||||
|
randomID, _ := generateHexID(16)
|
||||||
|
params := map[string]string{
|
||||||
|
"Format": "JSON",
|
||||||
|
"Version": "2022-03-02",
|
||||||
|
"SignatureMethod": "Hmac-SHA1",
|
||||||
|
"SignatureNonce": randomID,
|
||||||
|
"SignatureVersion": "1.0",
|
||||||
|
"Action": "TextModerationPlus",
|
||||||
|
"AccessKeyId": config.ak,
|
||||||
|
"Timestamp": timestamp,
|
||||||
|
"Service": "llm_query_moderation",
|
||||||
|
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||||
|
}
|
||||||
|
signature := getSign(params, config.sk+"&")
|
||||||
|
reqParams := url.Values{}
|
||||||
|
for k, v := range params {
|
||||||
|
reqParams.Add(k, v)
|
||||||
|
}
|
||||||
|
reqParams.Add("Signature", signature)
|
||||||
|
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
respData := gjson.GetBytes(responseBody, "Data")
|
||||||
|
if respData.Exists() {
|
||||||
|
respAdvice := respData.Get("Advice")
|
||||||
|
respResult := respData.Get("Result")
|
||||||
|
if respAdvice.Exists() {
|
||||||
|
sr := StandardResponse{
|
||||||
|
Code: 403,
|
||||||
|
Phase: "Request",
|
||||||
|
Message: respAdvice.Array()[0].Get("Answer").String(),
|
||||||
|
}
|
||||||
|
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||||
|
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||||
|
sr := StandardResponse{
|
||||||
|
Code: 403,
|
||||||
|
Phase: "Request",
|
||||||
|
Message: "risk detected",
|
||||||
|
}
|
||||||
|
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||||
|
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return types.ActionPause
|
||||||
|
} else {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertHeaders(hs [][2]string) map[string][]string {
|
||||||
|
ret := make(map[string][]string)
|
||||||
|
for _, h := range hs {
|
||||||
|
k, v := strings.ToLower(h[0]), h[1]
|
||||||
|
ret[k] = append(ret[k], v)
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// headers: map[string][]string -> [][2]string
|
||||||
|
func reconvertHeaders(hs map[string][]string) [][2]string {
|
||||||
|
var ret [][2]string
|
||||||
|
for k, vs := range hs {
|
||||||
|
for _, v := range vs {
|
||||||
|
ret = append(ret, [2]string{k, v})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.SliceStable(ret, func(i, j int) bool {
|
||||||
|
return ret[i][0] < ret[j][0]
|
||||||
|
})
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||||
|
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get response headers: %v", err)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
hdsMap := convertHeaders(headers)
|
||||||
|
ctx.SetContext("headers", hdsMap)
|
||||||
|
return types.HeaderStopIteration
|
||||||
|
}
|
||||||
|
|
||||||
|
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||||
|
messages := gjson.GetBytes(body, "choices").Array()
|
||||||
|
if len(messages) > 0 {
|
||||||
|
content := messages[0].Get("message").Get("content").String()
|
||||||
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||||
|
randomID, _ := generateHexID(16)
|
||||||
|
params := map[string]string{
|
||||||
|
"Format": "JSON",
|
||||||
|
"Version": "2022-03-02",
|
||||||
|
"SignatureMethod": "Hmac-SHA1",
|
||||||
|
"SignatureNonce": randomID,
|
||||||
|
"SignatureVersion": "1.0",
|
||||||
|
"Action": "TextModerationPlus",
|
||||||
|
"AccessKeyId": config.ak,
|
||||||
|
"Timestamp": timestamp,
|
||||||
|
"Service": "llm_response_moderation",
|
||||||
|
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||||
|
}
|
||||||
|
signature := getSign(params, config.sk+"&")
|
||||||
|
reqParams := url.Values{}
|
||||||
|
for k, v := range params {
|
||||||
|
reqParams.Add(k, v)
|
||||||
|
}
|
||||||
|
reqParams.Add("Signature", signature)
|
||||||
|
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
defer proxywasm.ResumeHttpResponse()
|
||||||
|
respData := gjson.GetBytes(responseBody, "Data")
|
||||||
|
if respData.Exists() {
|
||||||
|
respAdvice := respData.Get("Advice")
|
||||||
|
respResult := respData.Get("Result")
|
||||||
|
if respAdvice.Exists() {
|
||||||
|
sr := StandardResponse{
|
||||||
|
Code: 403,
|
||||||
|
Phase: "Response",
|
||||||
|
Message: respAdvice.Array()[0].Get("Answer").String(),
|
||||||
|
}
|
||||||
|
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||||
|
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||||
|
delete(hdsMap, "content-length")
|
||||||
|
hdsMap[":status"] = []string{"403"}
|
||||||
|
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||||
|
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||||
|
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||||
|
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||||
|
sr := StandardResponse{
|
||||||
|
Code: 403,
|
||||||
|
Phase: "Response",
|
||||||
|
Message: "risk detected",
|
||||||
|
}
|
||||||
|
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||||
|
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||||
|
delete(hdsMap, "content-length")
|
||||||
|
hdsMap[":status"] = []string{"403"}
|
||||||
|
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||||
|
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||||
|
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return types.ActionPause
|
||||||
|
} else {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user