mirror of
https://github.com/alibaba/higress.git
synced 2026-06-03 17:47:25 +08:00
[feat] ai-security-guard refactor & support checking multimoadl input (#3075)
This commit is contained in:
@@ -0,0 +1,249 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
ALGORITHM = "ACS3-HMAC-SHA256"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
httpMethod string
|
||||
canonicalUri string
|
||||
host string
|
||||
xAcsAction string
|
||||
xAcsVersion string
|
||||
headers map[string]string
|
||||
body []byte
|
||||
queryParam map[string]interface{}
|
||||
}
|
||||
|
||||
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
|
||||
req := &Request{
|
||||
httpMethod: httpMethod,
|
||||
canonicalUri: canonicalUri,
|
||||
host: host,
|
||||
xAcsAction: xAcsAction,
|
||||
xAcsVersion: xAcsVersion,
|
||||
headers: make(map[string]string),
|
||||
queryParam: make(map[string]interface{}),
|
||||
}
|
||||
req.headers["host"] = host
|
||||
req.headers["x-acs-action"] = xAcsAction
|
||||
req.headers["x-acs-version"] = xAcsVersion
|
||||
req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339)
|
||||
req.headers["x-acs-signature-nonce"] = uuid.New().String()
|
||||
return req
|
||||
}
|
||||
|
||||
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
|
||||
newQueryParams := make(map[string]interface{})
|
||||
processObject(newQueryParams, "", req.queryParam)
|
||||
req.queryParam = newQueryParams
|
||||
canonicalQueryString := ""
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&"
|
||||
}
|
||||
canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&")
|
||||
|
||||
var bodyContent []byte
|
||||
if req.body == nil {
|
||||
bodyContent = []byte("")
|
||||
} else {
|
||||
bodyContent = req.body
|
||||
}
|
||||
hashedRequestPayload := sha256Hex(bodyContent)
|
||||
req.headers["x-acs-content-sha256"] = hashedRequestPayload
|
||||
|
||||
if SecurityToken != "" {
|
||||
req.headers["x-acs-security-token"] = SecurityToken
|
||||
}
|
||||
|
||||
canonicalHeaders := ""
|
||||
signedHeaders := ""
|
||||
HeadersKeys := maps.Keys(req.headers)
|
||||
sort.Strings(HeadersKeys)
|
||||
for _, k := range HeadersKeys {
|
||||
lowerKey := strings.ToLower(k)
|
||||
if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
|
||||
canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
|
||||
signedHeaders += lowerKey + ";"
|
||||
}
|
||||
}
|
||||
signedHeaders = strings.TrimSuffix(signedHeaders, ";")
|
||||
|
||||
canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
|
||||
|
||||
hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest))
|
||||
stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest
|
||||
|
||||
byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
panic(err)
|
||||
}
|
||||
signature := strings.ToLower(hex.EncodeToString(byteData))
|
||||
|
||||
authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
|
||||
req.headers["Authorization"] = authorization
|
||||
}
|
||||
|
||||
func hmac256(key []byte, toSignString string) ([]byte, error) {
|
||||
h := hmac.New(sha256.New, key)
|
||||
_, err := h.Write([]byte(toSignString))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
func sha256Hex(byteArray []byte) string {
|
||||
hash := sha256.New()
|
||||
_, _ = hash.Write(byteArray)
|
||||
hexString := hex.EncodeToString(hash.Sum(nil))
|
||||
return hexString
|
||||
}
|
||||
|
||||
func percentCode(str string) string {
|
||||
str = strings.ReplaceAll(str, "+", "%20")
|
||||
str = strings.ReplaceAll(str, "*", "%2A")
|
||||
str = strings.ReplaceAll(str, "%7E", "~")
|
||||
return str
|
||||
}
|
||||
|
||||
func formDataToString(formData map[string]interface{}) *string {
|
||||
tmp := make(map[string]interface{})
|
||||
processObject(tmp, "", formData)
|
||||
res := ""
|
||||
urlEncoder := url.Values{}
|
||||
for key, value := range tmp {
|
||||
v := fmt.Sprintf("%v", value)
|
||||
urlEncoder.Add(key, v)
|
||||
}
|
||||
res = urlEncoder.Encode()
|
||||
return &res
|
||||
}
|
||||
|
||||
// processObject 递归处理对象,将复杂对象(如Map和List)展开为平面的键值对
|
||||
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
for i, item := range v {
|
||||
processObject(mapResult, fmt.Sprintf("%s.%d", key, i+1), item)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for subKey, subValue := range v {
|
||||
processObject(mapResult, fmt.Sprintf("%s.%s", key, subKey), subValue)
|
||||
}
|
||||
default:
|
||||
if strings.HasPrefix(key, ".") {
|
||||
key = key[1:]
|
||||
}
|
||||
if b, ok := v.([]byte); ok {
|
||||
mapResult[key] = string(b)
|
||||
} else {
|
||||
mapResult[key] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
|
||||
httpMethod := "POST"
|
||||
canonicalUri := "/"
|
||||
xAcsVersion := "2022-03-02"
|
||||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||
|
||||
req.queryParam["Service"] = checkService
|
||||
|
||||
body := make(map[string]interface{})
|
||||
serviceParameters := make(map[string]interface{})
|
||||
serviceParameters["content"] = text
|
||||
serviceParameters["sessionId"] = sessionID
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
str := formDataToString(body)
|
||||
req.body = []byte(*str)
|
||||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||
|
||||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||
|
||||
q := url.Values{}
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
for k, v := range req.headers {
|
||||
if k != "host" {
|
||||
headers = append(headers, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return "?" + q.Encode(), headers, req.body
|
||||
}
|
||||
|
||||
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
|
||||
httpMethod := "POST"
|
||||
canonicalUri := "/"
|
||||
xAcsVersion := "2022-03-02"
|
||||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||
|
||||
req.queryParam["Service"] = checkService
|
||||
|
||||
body := make(map[string]interface{})
|
||||
serviceParameters := make(map[string]interface{})
|
||||
if imgUrl != "" {
|
||||
serviceParameters["imageUrls"] = []string{imgUrl}
|
||||
}
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
if imgBase64 != "" {
|
||||
body["ImageBase64Str"] = imgBase64
|
||||
}
|
||||
str := formDataToString(body)
|
||||
req.body = []byte(*str)
|
||||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||
|
||||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||
|
||||
q := url.Values{}
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
for k, v := range req.headers {
|
||||
// host will be added by envoy automatically
|
||||
if k != "host" {
|
||||
headers = append(headers, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return "?" + q.Encode(), headers, req.body
|
||||
}
|
||||
Reference in New Issue
Block a user