mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
250 lines
7.0 KiB
Go
250 lines
7.0 KiB
Go
package common
|
||
|
||
import (
|
||
"crypto/hmac"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"sort"
|
||
|
||
"golang.org/x/exp/maps"
|
||
|
||
"fmt"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
|
||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// ALGORITHM identifies the ACS3 request-signing scheme. It is the first line of
// the string-to-sign and the prefix of the Authorization header value built by
// getAuthorization.
const ALGORITHM = "ACS3-HMAC-SHA256"
|
||
|
||
// Request collects everything needed to sign and send one ACS3-HMAC-SHA256
// call. It is created by newRequest and completed by getAuthorization.
type Request struct {
	// HTTP method, canonical URI path and target host of the call.
	httpMethod, canonicalUri, host string
	// ACS action name and API version (newRequest mirrors them into the
	// x-acs-action / x-acs-version headers).
	xAcsAction, xAcsVersion string
	// headers maps header name to value; getAuthorization adds the signing
	// headers (x-acs-content-sha256, optional x-acs-security-token, Authorization).
	headers map[string]string
	// body is the raw request payload whose SHA-256 is signed.
	body []byte
	// queryParam holds query parameters; getAuthorization replaces it with a
	// flattened copy (see processObject).
	queryParam map[string]interface{}
}
|
||
|
||
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
|
||
req := &Request{
|
||
httpMethod: httpMethod,
|
||
canonicalUri: canonicalUri,
|
||
host: host,
|
||
xAcsAction: xAcsAction,
|
||
xAcsVersion: xAcsVersion,
|
||
headers: make(map[string]string),
|
||
queryParam: make(map[string]interface{}),
|
||
}
|
||
req.headers["host"] = host
|
||
req.headers["x-acs-action"] = xAcsAction
|
||
req.headers["x-acs-version"] = xAcsVersion
|
||
req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339)
|
||
req.headers["x-acs-signature-nonce"] = uuid.New().String()
|
||
return req
|
||
}
|
||
|
||
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
|
||
newQueryParams := make(map[string]interface{})
|
||
processObject(newQueryParams, "", req.queryParam)
|
||
req.queryParam = newQueryParams
|
||
canonicalQueryString := ""
|
||
keys := maps.Keys(req.queryParam)
|
||
sort.Strings(keys)
|
||
for _, k := range keys {
|
||
v := req.queryParam[k]
|
||
canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&"
|
||
}
|
||
canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&")
|
||
|
||
var bodyContent []byte
|
||
if req.body == nil {
|
||
bodyContent = []byte("")
|
||
} else {
|
||
bodyContent = req.body
|
||
}
|
||
hashedRequestPayload := sha256Hex(bodyContent)
|
||
req.headers["x-acs-content-sha256"] = hashedRequestPayload
|
||
|
||
if SecurityToken != "" {
|
||
req.headers["x-acs-security-token"] = SecurityToken
|
||
}
|
||
|
||
canonicalHeaders := ""
|
||
signedHeaders := ""
|
||
HeadersKeys := maps.Keys(req.headers)
|
||
sort.Strings(HeadersKeys)
|
||
for _, k := range HeadersKeys {
|
||
lowerKey := strings.ToLower(k)
|
||
if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
|
||
canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
|
||
signedHeaders += lowerKey + ";"
|
||
}
|
||
}
|
||
signedHeaders = strings.TrimSuffix(signedHeaders, ";")
|
||
|
||
canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
|
||
|
||
hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest))
|
||
stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest
|
||
|
||
byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
|
||
if err != nil {
|
||
fmt.Println(err)
|
||
panic(err)
|
||
}
|
||
signature := strings.ToLower(hex.EncodeToString(byteData))
|
||
|
||
authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
|
||
req.headers["Authorization"] = authorization
|
||
}
|
||
|
||
// hmac256 returns the raw HMAC-SHA256 of toSignString under key. The error is
// whatever the underlying hash Write reports.
func hmac256(key []byte, toSignString string) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	if _, writeErr := mac.Write([]byte(toSignString)); writeErr != nil {
		return nil, writeErr
	}
	return mac.Sum(nil), nil
}
|
||
|
||
// sha256Hex returns the lower-case hex encoding of the SHA-256 digest of
// byteArray (nil is hashed as empty input).
func sha256Hex(byteArray []byte) string {
	digest := sha256.Sum256(byteArray)
	return hex.EncodeToString(digest[:])
}
|
||
|
||
// percentCodeReplacer turns url.QueryEscape output into RFC 3986 style:
// "+" (space) becomes "%20", "*" becomes "%2A", and an escaped "~" is restored.
var percentCodeReplacer = strings.NewReplacer(
	"+", "%20",
	"*", "%2A",
	"%7E", "~",
)

// percentCode post-processes a url.QueryEscape result into the percent-encoding
// used by the canonical query string.
func percentCode(str string) string {
	return percentCodeReplacer.Replace(str)
}
|
||
|
||
func formDataToString(formData map[string]interface{}) *string {
|
||
tmp := make(map[string]interface{})
|
||
processObject(tmp, "", formData)
|
||
res := ""
|
||
urlEncoder := url.Values{}
|
||
for key, value := range tmp {
|
||
v := fmt.Sprintf("%v", value)
|
||
urlEncoder.Add(key, v)
|
||
}
|
||
res = urlEncoder.Encode()
|
||
return &res
|
||
}
|
||
|
||
// processObject recursively flattens value into mapResult as string leaves.
// Nested map keys are joined with "." and []interface{} elements get 1-based
// indices ("key.1", "key.2", ...). A single leading "." is stripped from leaf
// keys, so top-level entries flattened from key "" have no prefix. []byte
// leaves are stored as their string contents; other leaves are formatted
// with %v. A nil value contributes nothing.
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
	switch node := value.(type) {
	case nil:
		return
	case []interface{}:
		for idx := range node {
			processObject(mapResult, fmt.Sprintf("%s.%d", key, idx+1), node[idx])
		}
	case map[string]interface{}:
		for childKey, childValue := range node {
			processObject(mapResult, fmt.Sprintf("%s.%s", key, childKey), childValue)
		}
	case []byte:
		mapResult[strings.TrimPrefix(key, ".")] = string(node)
	default:
		mapResult[strings.TrimPrefix(key, ".")] = fmt.Sprintf("%v", node)
	}
}
|
||
|
||
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
|
||
httpMethod := "POST"
|
||
canonicalUri := "/"
|
||
xAcsVersion := "2022-03-02"
|
||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||
|
||
req.queryParam["Service"] = checkService
|
||
|
||
body := make(map[string]interface{})
|
||
serviceParameters := make(map[string]interface{})
|
||
serviceParameters["content"] = text
|
||
serviceParameters["sessionId"] = sessionID
|
||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||
body["ServiceParameters"] = serviceParametersJSON
|
||
str := formDataToString(body)
|
||
req.body = []byte(*str)
|
||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||
|
||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||
|
||
q := url.Values{}
|
||
keys := maps.Keys(req.queryParam)
|
||
sort.Strings(keys)
|
||
for _, k := range keys {
|
||
v := req.queryParam[k]
|
||
q.Set(k, fmt.Sprintf("%v", v))
|
||
}
|
||
for k, v := range req.headers {
|
||
if k != "host" {
|
||
headers = append(headers, [2]string{k, v})
|
||
}
|
||
}
|
||
return "?" + q.Encode(), headers, req.body
|
||
}
|
||
|
||
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
|
||
httpMethod := "POST"
|
||
canonicalUri := "/"
|
||
xAcsVersion := "2022-03-02"
|
||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||
|
||
req.queryParam["Service"] = checkService
|
||
|
||
body := make(map[string]interface{})
|
||
serviceParameters := make(map[string]interface{})
|
||
if imgUrl != "" {
|
||
serviceParameters["imageUrls"] = []string{imgUrl}
|
||
}
|
||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||
body["ServiceParameters"] = serviceParametersJSON
|
||
if imgBase64 != "" {
|
||
body["ImageBase64Str"] = imgBase64
|
||
}
|
||
str := formDataToString(body)
|
||
req.body = []byte(*str)
|
||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||
|
||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||
|
||
q := url.Values{}
|
||
keys := maps.Keys(req.queryParam)
|
||
sort.Strings(keys)
|
||
for _, k := range keys {
|
||
v := req.queryParam[k]
|
||
q.Set(k, fmt.Sprintf("%v", v))
|
||
}
|
||
for k, v := range req.headers {
|
||
// host will be added by envoy automatically
|
||
if k != "host" {
|
||
headers = append(headers, [2]string{k, v})
|
||
}
|
||
}
|
||
return "?" + q.Encode(), headers, req.body
|
||
}
|