mirror of
https://github.com/alibaba/higress.git
synced 2026-03-15 22:30:47 +08:00
1004 lines
33 KiB
Go
1004 lines
33 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
mrand "math/rand"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/higress-group/wasm-go/pkg/log"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// main is intentionally empty. The plugin's behaviour is wired up in init,
// which registers the HTTP callbacks through wrapper.SetCtx; main only exists
// because package main requires it.
func main() {
}
|
|
|
|
func init() {
|
|
wrapper.SetCtx(
|
|
"ai-security-guard",
|
|
wrapper.ParseConfig(parseConfig),
|
|
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
|
wrapper.ProcessRequestBody(onHttpRequestBody),
|
|
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
|
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
|
|
wrapper.ProcessResponseBody(onHttpResponseBody),
|
|
)
|
|
}
|
|
|
|
// Risk levels reported by the content-security service and accepted as
// configured thresholds. levelToInt ranks them max (4) down to none (0).
const (
	MaxRisk    = "max"
	HighRisk   = "high"
	MediumRisk = "medium"
	LowRisk    = "low"
	NoRisk     = "none"
)

// Sensitivity grades used for sensitive-data detection. levelToInt ranks
// them S4 (4) down to S0 (0).
const (
	S4Sensitive = "S4"
	S3Sensitive = "S3"
	S2Sensitive = "S2"
	S1Sensitive = "S1"
	NoSensitive = "S0"
)

// Detail.Type values that isRiskLevelAcceptable maps to a per-type threshold
// for the MultiModalGuard action.
const (
	ContentModerationType      = "contentModeration"
	PromptAttackType           = "promptAttack"
	SensitiveDataType          = "sensitiveData"
	MaliciousUrlDataType       = "maliciousUrl"
	ModelHallucinationDataType = "modelHallucination"
)

// fmt templates for synthesized OpenAI-compatible deny replies. The streaming
// variant expects three arguments: chunk id, escaped message, end-chunk id.
const (
	OpenAIResponseFormat       = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	OpenAIStreamResponseChunk  = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
	OpenAIStreamResponseEnd    = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
)

// Defaults applied by parseConfig when the corresponding key is absent.
const (
	DefaultRequestCheckService       = "llm_query_moderation"
	DefaultResponseCheckService      = "llm_response_moderation"
	DefaultRequestJsonPath           = "messages.@reverse.0.content"
	DefaultResponseJsonPath          = "choices.0.message.content"
	DefaultStreamingResponseJsonPath = "choices.0.delta.content"
	DefaultDenyCode                  = 200
	// DefaultDenyMessage means "Sorry, I can't answer your question".
	DefaultDenyMessage = "很抱歉,我无法回答您的问题"
	DefaultTimeout     = 2000
)

const (
	// AliyunUserAgent is sent as User-Agent and as the requestFrom service parameter.
	AliyunUserAgent = "CIPFrom/AIGateway"
	// LengthLimit is the maximum size of one content piece sent per check call.
	LengthLimit = 1800
)
|
|
|
|
// Wire types for the content-security API answer. Only decoding is performed
// (json.Unmarshal in the body handlers); field names follow the service's
// PascalCase JSON keys.
type (
	// Response is the top-level envelope; Code 200 means the check succeeded.
	Response struct {
		Code      int    `json:"Code"`
		Message   string `json:"Message"`
		RequestId string `json:"RequestId"`
		Data      Data   `json:"Data"`
	}

	// Data carries the verdict: an overall RiskLevel, an AttackLevel (read for
	// MultiModalGuard), and optional per-hit results, advice and details.
	Data struct {
		RiskLevel   string   `json:"RiskLevel"`
		AttackLevel string   `json:"AttackLevel,omitempty"`
		Result      []Result `json:"Result,omitempty"`
		Advice      []Advice `json:"Advice,omitempty"`
		Detail      []Detail `json:"Detail,omitempty"`
	}

	// Result describes one detection hit; Label and RiskWords are logged on deny.
	Result struct {
		RiskWords   string  `json:"RiskWords,omitempty"`
		Description string  `json:"Description,omitempty"`
		Confidence  float64 `json:"Confidence,omitempty"`
		Label       string  `json:"Label,omitempty"`
	}

	// Advice may carry a suggested Answer that can be used as the deny message.
	Advice struct {
		Answer     string `json:"Answer,omitempty"`
		HitLabel   string `json:"HitLabel,omitempty"`
		HitLibName string `json:"HitLibName,omitempty"`
	}

	// Detail is a per-category level; Type selects which threshold applies.
	Detail struct {
		Suggestion string `json:"Suggestion,omitempty"`
		Type       string `json:"Type,omitempty"`
		Level      string `json:"Level,omitempty"`
	}
)
|
|
|
|
type AISecurityConfig struct {
|
|
client wrapper.HttpClient
|
|
ak string
|
|
sk string
|
|
token string
|
|
action string
|
|
checkRequest bool
|
|
requestCheckService string
|
|
requestContentJsonPath string
|
|
checkResponse bool
|
|
responseCheckService string
|
|
responseContentJsonPath string
|
|
responseStreamContentJsonPath string
|
|
denyCode int64
|
|
denyMessage string
|
|
protocolOriginal bool
|
|
riskLevelBar string
|
|
contentModerationLevelBar string
|
|
promptAttackLevelBar string
|
|
sensitiveDataLevelBar string
|
|
maliciousUrlLevelBar string
|
|
modelHallucinationLevelBar string
|
|
timeout uint32
|
|
bufferLimit int
|
|
metrics map[string]proxywasm.MetricCounter
|
|
consumerRequestCheckService []map[string]interface{}
|
|
consumerResponseCheckService []map[string]interface{}
|
|
consumerRiskLevel []map[string]interface{}
|
|
}
|
|
|
|
// Matcher selects the consumers (from the x-mse-consumer request header) that
// a per-consumer override rule applies to. Only one strategy is expected to be
// set; if several are, Exact takes precedence over Prefix, and Prefix over Re.
type Matcher struct {
	Exact  string         // consumer name must equal this
	Prefix string         // consumer name must start with this
	Re     *regexp.Regexp // consumer name must contain a match of this expression
}

// match reports whether consumer is selected by m. A Matcher with no strategy
// set (empty Exact and Prefix, nil Re) selects nothing.
func (m *Matcher) match(consumer string) bool {
	switch {
	case m.Exact != "":
		return consumer == m.Exact
	case m.Prefix != "":
		return strings.HasPrefix(consumer, m.Prefix)
	case m.Re != nil:
		return m.Re.MatchString(consumer)
	}
	return false
}
|
|
|
|
func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
|
|
counter, ok := config.metrics[metricName]
|
|
if !ok {
|
|
counter = proxywasm.DefineCounterMetric(metricName)
|
|
config.metrics[metricName] = counter
|
|
}
|
|
counter.Increment(inc)
|
|
}
|
|
|
|
func (config *AISecurityConfig) getRequestCheckService(consumer string) string {
|
|
result := config.requestCheckService
|
|
for _, obj := range config.consumerRequestCheckService {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if requestCheckService, ok := obj["requestCheckService"]; ok {
|
|
result, _ = requestCheckService.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getResponseCheckService(consumer string) string {
|
|
result := config.responseCheckService
|
|
for _, obj := range config.consumerResponseCheckService {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if responseCheckService, ok := obj["responseCheckService"]; ok {
|
|
result, _ = responseCheckService.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getRiskLevelBar(consumer string) string {
|
|
result := config.riskLevelBar
|
|
for _, obj := range config.consumerRiskLevel {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
|
|
result, _ = riskLevelBar.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getContentModerationLevelBar(consumer string) string {
|
|
result := config.contentModerationLevelBar
|
|
for _, obj := range config.consumerRiskLevel {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
|
|
result, _ = contentModerationLevelBar.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getPromptAttackLevelBar(consumer string) string {
|
|
result := config.promptAttackLevelBar
|
|
for _, obj := range config.consumerRiskLevel {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
|
|
result, _ = promptAttackLevelBar.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getSensitiveDataLevelBar(consumer string) string {
|
|
result := config.sensitiveDataLevelBar
|
|
for _, obj := range config.consumerRiskLevel {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
|
|
result, _ = sensitiveDataLevelBar.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getMaliciousUrlLevelBar(consumer string) string {
|
|
result := config.maliciousUrlLevelBar
|
|
for _, obj := range config.consumerRiskLevel {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
|
|
result, _ = maliciousUrlLevelBar.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (config *AISecurityConfig) getModelHallucinationLevelBar(consumer string) string {
|
|
result := config.modelHallucinationLevelBar
|
|
for _, obj := range config.consumerRiskLevel {
|
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
|
if matcher.match(consumer) {
|
|
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
|
|
result, _ = modelHallucinationLevelBar.(string)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// levelToInt maps a risk level (max/high/medium/low/none) or a sensitivity
// grade (S4/S3/S2/S1/S0) to an ordinal severity from 4 down to 0, so levels
// and thresholds can be compared numerically. Matching is case-insensitive,
// a superset of the previously accepted lower/upper-case spellings. Empty or
// unknown input yields -1, which ranks below every real level.
func levelToInt(riskLevel string) int {
	switch strings.ToLower(riskLevel) {
	case "max", "s4":
		return 4
	case "high", "s3":
		return 3
	case "medium", "s2":
		return 2
	case "low", "s1":
		return 1
	case "none", "s0":
		return 0
	default:
		return -1
	}
}
|
|
|
|
func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
|
|
if action == "MultiModalGuard" {
|
|
// Check top-level risk levels for MultiModalGuard
|
|
if levelToInt(data.RiskLevel) >= levelToInt(config.getContentModerationLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
// Also check AttackLevel for prompt attack detection
|
|
if levelToInt(data.AttackLevel) >= levelToInt(config.getPromptAttackLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
|
|
// Check detailed results for backward compatibility
|
|
for _, detail := range data.Detail {
|
|
switch detail.Type {
|
|
case ContentModerationType:
|
|
if levelToInt(detail.Level) >= levelToInt(config.getContentModerationLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
case PromptAttackType:
|
|
if levelToInt(detail.Level) >= levelToInt(config.getPromptAttackLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
case SensitiveDataType:
|
|
if levelToInt(detail.Level) >= levelToInt(config.getSensitiveDataLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
case MaliciousUrlDataType:
|
|
if levelToInt(detail.Level) >= levelToInt(config.getMaliciousUrlLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
case ModelHallucinationDataType:
|
|
if levelToInt(detail.Level) >= levelToInt(config.getModelHallucinationLevelBar(consumer)) {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
} else {
|
|
return levelToInt(data.RiskLevel) < levelToInt(config.getRiskLevelBar(consumer))
|
|
}
|
|
}
|
|
|
|
// reservedPathEscaper percent-encodes the reserved characters that
// url.PathEscape leaves literal inside a path segment.
var reservedPathEscaper = strings.NewReplacer(
	"+", "%2B",
	":", "%3A",
	"=", "%3D",
	"&", "%26",
	"$", "%24",
	"@", "%40",
)

// urlEncoding percent-encodes rawStr so that only the RFC 3986 unreserved
// characters (A-Z a-z 0-9 - _ . ~) stay literal. Everything else becomes %XX
// with upper-case hex; a space becomes %20, not "+". getSign uses it for
// keys, values and the canonical query string.
func urlEncoding(rawStr string) string {
	return reservedPathEscaper.Replace(url.PathEscape(rawStr))
}
|
|
|
|
// hmacSha1 returns the standard base64 encoding of HMAC-SHA1(secret, message).
func hmacSha1(message, secret string) string {
	mac := hmac.New(sha1.New, []byte(secret))
	mac.Write([]byte(message))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}
|
|
|
|
func getSign(params map[string]string, secret string) string {
|
|
paramArray := []string{}
|
|
for k, v := range params {
|
|
paramArray = append(paramArray, urlEncoding(k)+"="+urlEncoding(v))
|
|
}
|
|
sort.Slice(paramArray, func(i, j int) bool {
|
|
return paramArray[i] <= paramArray[j]
|
|
})
|
|
canonicalStr := strings.Join(paramArray, "&")
|
|
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
|
proxywasm.LogDebugf("String to sign is: %s", signStr)
|
|
return hmacSha1(signStr, secret)
|
|
}
|
|
|
|
// generateHexID returns a random lower-case hex string built from length/2
// bytes of crypto/rand output. The result therefore has length characters
// for even length and length-1 for odd length. The error from rand.Read is
// returned unchanged.
func generateHexID(length int) (string, error) {
	entropy := make([]byte, length/2)
	_, err := rand.Read(entropy)
	if err != nil {
		return "", err
	}
	id := hex.EncodeToString(entropy)
	return id, nil
}
|
|
|
|
func parseConfig(json gjson.Result, config *AISecurityConfig) error {
|
|
serviceName := json.Get("serviceName").String()
|
|
servicePort := json.Get("servicePort").Int()
|
|
serviceHost := json.Get("serviceHost").String()
|
|
if serviceName == "" || servicePort == 0 || serviceHost == "" {
|
|
return errors.New("invalid service config")
|
|
}
|
|
config.ak = json.Get("accessKey").String()
|
|
config.sk = json.Get("secretKey").String()
|
|
if config.ak == "" || config.sk == "" {
|
|
return errors.New("invalid AK/SK config")
|
|
}
|
|
if obj := json.Get("riskLevelBar"); obj.Exists() {
|
|
config.riskLevelBar = obj.String()
|
|
} else {
|
|
config.riskLevelBar = HighRisk
|
|
}
|
|
config.token = json.Get("securityToken").String()
|
|
if obj := json.Get("action"); obj.Exists() {
|
|
config.action = json.Get("action").String()
|
|
} else {
|
|
config.action = "TextModerationPlus"
|
|
}
|
|
config.checkRequest = json.Get("checkRequest").Bool()
|
|
config.checkResponse = json.Get("checkResponse").Bool()
|
|
config.protocolOriginal = json.Get("protocol").String() == "original"
|
|
config.denyMessage = json.Get("denyMessage").String()
|
|
if obj := json.Get("denyCode"); obj.Exists() {
|
|
config.denyCode = obj.Int()
|
|
} else {
|
|
config.denyCode = DefaultDenyCode
|
|
}
|
|
if obj := json.Get("requestCheckService"); obj.Exists() {
|
|
config.requestCheckService = obj.String()
|
|
} else {
|
|
config.requestCheckService = DefaultRequestCheckService
|
|
}
|
|
if obj := json.Get("responseCheckService"); obj.Exists() {
|
|
config.responseCheckService = obj.String()
|
|
} else {
|
|
config.responseCheckService = DefaultResponseCheckService
|
|
}
|
|
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
|
|
config.requestContentJsonPath = obj.String()
|
|
} else {
|
|
config.requestContentJsonPath = DefaultRequestJsonPath
|
|
}
|
|
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
|
|
config.responseContentJsonPath = obj.String()
|
|
} else {
|
|
config.responseContentJsonPath = DefaultResponseJsonPath
|
|
}
|
|
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
|
|
config.responseStreamContentJsonPath = obj.String()
|
|
} else {
|
|
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
|
}
|
|
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
|
|
config.contentModerationLevelBar = obj.String()
|
|
if levelToInt(config.contentModerationLevelBar) <= 0 {
|
|
return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
|
|
}
|
|
} else {
|
|
config.contentModerationLevelBar = MaxRisk
|
|
}
|
|
if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
|
|
config.promptAttackLevelBar = obj.String()
|
|
if levelToInt(config.promptAttackLevelBar) <= 0 {
|
|
return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
|
|
}
|
|
} else {
|
|
config.promptAttackLevelBar = MaxRisk
|
|
}
|
|
if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
|
|
config.sensitiveDataLevelBar = obj.String()
|
|
if levelToInt(config.sensitiveDataLevelBar) <= 0 {
|
|
return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
|
|
}
|
|
} else {
|
|
config.sensitiveDataLevelBar = S4Sensitive
|
|
}
|
|
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
|
|
config.modelHallucinationLevelBar = obj.String()
|
|
if levelToInt(config.modelHallucinationLevelBar) <= 0 {
|
|
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
|
|
}
|
|
} else {
|
|
config.modelHallucinationLevelBar = MaxRisk
|
|
}
|
|
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
|
|
config.maliciousUrlLevelBar = obj.String()
|
|
if levelToInt(config.maliciousUrlLevelBar) <= 0 {
|
|
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
|
|
}
|
|
} else {
|
|
config.maliciousUrlLevelBar = MaxRisk
|
|
}
|
|
if obj := json.Get("timeout"); obj.Exists() {
|
|
config.timeout = uint32(obj.Int())
|
|
} else {
|
|
config.timeout = DefaultTimeout
|
|
}
|
|
if obj := json.Get("bufferLimit"); obj.Exists() {
|
|
config.bufferLimit = int(obj.Int())
|
|
} else {
|
|
config.bufferLimit = 1000
|
|
}
|
|
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
|
|
for _, item := range json.Get("consumerRequestCheckService").Array() {
|
|
m := make(map[string]interface{})
|
|
for k, v := range item.Map() {
|
|
m[k] = v.Value()
|
|
}
|
|
consumerName, ok1 := m["name"]
|
|
matchType, ok2 := m["matchType"]
|
|
if !ok1 || !ok2 {
|
|
continue
|
|
}
|
|
switch fmt.Sprint(matchType) {
|
|
case "exact":
|
|
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
|
case "prefix":
|
|
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
|
case "regexp":
|
|
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
|
}
|
|
config.consumerRequestCheckService = append(config.consumerRequestCheckService, m)
|
|
}
|
|
}
|
|
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
|
|
for _, item := range json.Get("consumerResponseCheckService").Array() {
|
|
m := make(map[string]interface{})
|
|
for k, v := range item.Map() {
|
|
m[k] = v.Value()
|
|
}
|
|
consumerName, ok1 := m["name"]
|
|
matchType, ok2 := m["matchType"]
|
|
if !ok1 || !ok2 {
|
|
continue
|
|
}
|
|
switch fmt.Sprint(matchType) {
|
|
case "exact":
|
|
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
|
case "prefix":
|
|
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
|
case "regexp":
|
|
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
|
}
|
|
config.consumerResponseCheckService = append(config.consumerResponseCheckService, m)
|
|
}
|
|
}
|
|
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
|
|
for _, item := range json.Get("consumerRiskLevel").Array() {
|
|
m := make(map[string]interface{})
|
|
for k, v := range item.Map() {
|
|
m[k] = v.Value()
|
|
}
|
|
consumerName, ok1 := m["name"]
|
|
matchType, ok2 := m["matchType"]
|
|
if !ok1 || !ok2 {
|
|
continue
|
|
}
|
|
switch fmt.Sprint(matchType) {
|
|
case "exact":
|
|
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
|
case "prefix":
|
|
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
|
case "regexp":
|
|
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
|
}
|
|
config.consumerRiskLevel = append(config.consumerRiskLevel, m)
|
|
}
|
|
}
|
|
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
FQDN: serviceName,
|
|
Port: servicePort,
|
|
Host: serviceHost,
|
|
})
|
|
config.metrics = make(map[string]proxywasm.MetricCounter)
|
|
return nil
|
|
}
|
|
|
|
// generateRandomID returns an OpenAI-style completion id: "chatcmpl-"
// followed by 29 characters drawn from [a-zA-Z0-9] using math/rand. It is
// not suitable for anything security-sensitive.
func generateRandomID() string {
	const (
		prefix    = "chatcmpl-"
		alphabet  = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		suffixLen = 29
	)
	id := make([]byte, len(prefix), len(prefix)+suffixLen)
	copy(id, prefix)
	for len(id) < cap(id) {
		id = append(id, alphabet[mrand.Intn(len(alphabet))])
	}
	return string(id)
}
|
|
|
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action {
|
|
consumer, _ := proxywasm.GetHttpRequestHeader("x-mse-consumer")
|
|
ctx.SetContext("consumer", consumer)
|
|
ctx.DisableReroute()
|
|
if !config.checkRequest {
|
|
log.Debugf("request checking is disabled")
|
|
ctx.DontReadRequestBody()
|
|
}
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
|
|
consumer, _ := ctx.GetContext("consumer").(string)
|
|
log.Debugf("checking request body...")
|
|
startTime := time.Now().UnixMilli()
|
|
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
|
log.Debugf("Raw request content is: %s", content)
|
|
if len(content) == 0 {
|
|
log.Info("request content is empty. skip")
|
|
return types.ActionContinue
|
|
}
|
|
contentIndex := 0
|
|
sessionID, _ := generateHexID(20)
|
|
var singleCall func()
|
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Info(string(responseBody))
|
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
|
proxywasm.ResumeHttpRequest()
|
|
return
|
|
}
|
|
var response Response
|
|
err := json.Unmarshal(responseBody, &response)
|
|
if err != nil {
|
|
log.Error("failed to unmarshal aliyun content security response at request phase")
|
|
proxywasm.ResumeHttpRequest()
|
|
return
|
|
}
|
|
if isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
|
|
if contentIndex >= len(content) {
|
|
endTime := time.Now().UnixMilli()
|
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
proxywasm.ResumeHttpRequest()
|
|
} else {
|
|
singleCall()
|
|
}
|
|
return
|
|
}
|
|
denyMessage := DefaultDenyMessage
|
|
if config.denyMessage != "" {
|
|
denyMessage = config.denyMessage
|
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
|
denyMessage = response.Data.Advice[0].Answer
|
|
}
|
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
|
if config.protocolOriginal {
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
|
} else if gjson.GetBytes(body, "stream").Bool() {
|
|
randomID := generateRandomID()
|
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
|
} else {
|
|
randomID := generateRandomID()
|
|
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
|
}
|
|
ctx.DontReadResponseBody()
|
|
config.incrementCounter("ai_sec_request_deny", 1)
|
|
endTime := time.Now().UnixMilli()
|
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
|
if response.Data.Advice != nil {
|
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
|
}
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
}
|
|
singleCall = func() {
|
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
|
randomID, _ := generateHexID(16)
|
|
var nextContentIndex int
|
|
if contentIndex+LengthLimit >= len(content) {
|
|
nextContentIndex = len(content)
|
|
} else {
|
|
nextContentIndex = contentIndex + LengthLimit
|
|
}
|
|
contentPiece := content[contentIndex:nextContentIndex]
|
|
contentIndex = nextContentIndex
|
|
log.Debugf("current content piece: %s", contentPiece)
|
|
checkService := config.getRequestCheckService(consumer)
|
|
params := map[string]string{
|
|
"Format": "JSON",
|
|
"Version": "2022-03-02",
|
|
"SignatureMethod": "Hmac-SHA1",
|
|
"SignatureNonce": randomID,
|
|
"SignatureVersion": "1.0",
|
|
"Action": config.action,
|
|
"AccessKeyId": config.ak,
|
|
"Timestamp": timestamp,
|
|
"Service": checkService,
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent),
|
|
}
|
|
if config.token != "" {
|
|
params["SecurityToken"] = config.token
|
|
}
|
|
signature := getSign(params, config.sk+"&")
|
|
reqParams := url.Values{}
|
|
for k, v := range params {
|
|
reqParams.Add(k, v)
|
|
}
|
|
reqParams.Add("Signature", signature)
|
|
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
|
|
if err != nil {
|
|
log.Errorf("failed call the safe check service: %v", err)
|
|
proxywasm.ResumeHttpRequest()
|
|
}
|
|
}
|
|
singleCall()
|
|
return types.ActionPause
|
|
}
|
|
|
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action {
|
|
if !config.checkResponse {
|
|
log.Debugf("response checking is disabled")
|
|
ctx.DontReadResponseBody()
|
|
return types.ActionContinue
|
|
}
|
|
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
|
|
if statusCode != "200" {
|
|
log.Debugf("response is not 200, skip response body check")
|
|
ctx.DontReadResponseBody()
|
|
return types.ActionContinue
|
|
}
|
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
|
ctx.SetContext("end_of_stream_received", false)
|
|
ctx.SetContext("during_call", false)
|
|
ctx.SetContext("risk_detected", false)
|
|
sessionID, _ := generateHexID(20)
|
|
ctx.SetContext("sessionID", sessionID)
|
|
if strings.Contains(contentType, "text/event-stream") {
|
|
ctx.NeedPauseStreamingResponse()
|
|
return types.ActionContinue
|
|
} else {
|
|
ctx.BufferResponseBody()
|
|
return types.HeaderStopIteration
|
|
}
|
|
}
|
|
|
|
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, data []byte, endOfStream bool) []byte {
|
|
consumer, _ := ctx.GetContext("consumer").(string)
|
|
var bufferQueue [][]byte
|
|
var singleCall func()
|
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Info(string(responseBody))
|
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
|
proxywasm.ResumeHttpResponse()
|
|
}
|
|
ctx.SetContext("during_call", false)
|
|
return
|
|
}
|
|
var response Response
|
|
err := json.Unmarshal(responseBody, &response)
|
|
if err != nil {
|
|
log.Error("failed to unmarshal aliyun content security response at response phase")
|
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
|
proxywasm.ResumeHttpResponse()
|
|
}
|
|
ctx.SetContext("during_call", false)
|
|
return
|
|
}
|
|
if !isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
|
|
denyMessage := DefaultDenyMessage
|
|
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
|
denyMessage = "\n" + response.Data.Advice[0].Answer
|
|
} else if config.denyMessage != "" {
|
|
denyMessage = config.denyMessage
|
|
}
|
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
|
randomID := generateRandomID()
|
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
|
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
|
|
return
|
|
}
|
|
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
|
|
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
|
|
bufferQueue = [][]byte{}
|
|
if !endStream {
|
|
ctx.SetContext("during_call", false)
|
|
singleCall()
|
|
}
|
|
}
|
|
singleCall = func() {
|
|
if ctx.GetContext("during_call").(bool) {
|
|
return
|
|
}
|
|
if ctx.BufferQueueSize() >= config.bufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
|
|
var buffer string
|
|
for ctx.BufferQueueSize() > 0 {
|
|
front := ctx.PopBuffer()
|
|
bufferQueue = append(bufferQueue, front)
|
|
msg := gjson.GetBytes(front, config.responseStreamContentJsonPath).String()
|
|
buffer += msg
|
|
if len([]rune(buffer)) >= config.bufferLimit {
|
|
break
|
|
}
|
|
}
|
|
// if streaming body has reasoning_content, buffer maybe empty
|
|
log.Debugf("current content piece: %s", buffer)
|
|
if len(buffer) == 0 {
|
|
return
|
|
}
|
|
ctx.SetContext("during_call", true)
|
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
|
randomID, _ := generateHexID(16)
|
|
log.Debugf("current content piece: %s", buffer)
|
|
checkService := config.getResponseCheckService(consumer)
|
|
params := map[string]string{
|
|
"Format": "JSON",
|
|
"Version": "2022-03-02",
|
|
"SignatureMethod": "Hmac-SHA1",
|
|
"SignatureNonce": randomID,
|
|
"SignatureVersion": "1.0",
|
|
"Action": config.action,
|
|
"AccessKeyId": config.ak,
|
|
"Timestamp": timestamp,
|
|
"Service": checkService,
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), wrapper.MarshalStr(buffer), AliyunUserAgent),
|
|
}
|
|
if config.token != "" {
|
|
params["SecurityToken"] = config.token
|
|
}
|
|
signature := getSign(params, config.sk+"&")
|
|
reqParams := url.Values{}
|
|
for k, v := range params {
|
|
reqParams.Add(k, v)
|
|
}
|
|
reqParams.Add("Signature", signature)
|
|
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
|
|
if err != nil {
|
|
log.Errorf("failed call the safe check service: %v", err)
|
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
|
proxywasm.ResumeHttpResponse()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if !ctx.GetContext("risk_detected").(bool) {
|
|
for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
|
|
ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
|
|
}
|
|
ctx.SetContext("end_of_stream_received", endOfStream)
|
|
if !ctx.GetContext("during_call").(bool) {
|
|
singleCall()
|
|
}
|
|
} else if endOfStream {
|
|
proxywasm.ResumeHttpResponse()
|
|
}
|
|
return []byte{}
|
|
}
|
|
|
|
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action {
|
|
consumer, _ := ctx.GetContext("consumer").(string)
|
|
log.Debugf("checking response body...")
|
|
startTime := time.Now().UnixMilli()
|
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
|
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
|
var content string
|
|
if isStreamingResponse {
|
|
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
|
} else {
|
|
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
|
|
}
|
|
log.Debugf("Raw response content is: %s", content)
|
|
if len(content) == 0 {
|
|
log.Info("response content is empty. skip")
|
|
return types.ActionContinue
|
|
}
|
|
contentIndex := 0
|
|
sessionID, _ := generateHexID(20)
|
|
var singleCall func()
|
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Info(string(responseBody))
|
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
|
proxywasm.ResumeHttpResponse()
|
|
return
|
|
}
|
|
var response Response
|
|
err := json.Unmarshal(responseBody, &response)
|
|
if err != nil {
|
|
log.Error("failed to unmarshal aliyun content security response at response phase")
|
|
proxywasm.ResumeHttpResponse()
|
|
return
|
|
}
|
|
if isRiskLevelAcceptable(config.action, response.Data, config, consumer) {
|
|
if contentIndex >= len(content) {
|
|
endTime := time.Now().UnixMilli()
|
|
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
|
ctx.SetUserAttribute("safecheck_status", "response pass")
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
proxywasm.ResumeHttpResponse()
|
|
} else {
|
|
singleCall()
|
|
}
|
|
return
|
|
}
|
|
denyMessage := DefaultDenyMessage
|
|
if config.denyMessage != "" {
|
|
denyMessage = config.denyMessage
|
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
|
denyMessage = response.Data.Advice[0].Answer
|
|
}
|
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
|
if config.protocolOriginal {
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
|
} else if isStreamingResponse {
|
|
randomID := generateRandomID()
|
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
|
} else {
|
|
randomID := generateRandomID()
|
|
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
|
}
|
|
config.incrementCounter("ai_sec_response_deny", 1)
|
|
endTime := time.Now().UnixMilli()
|
|
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
|
ctx.SetUserAttribute("safecheck_status", "response deny")
|
|
if response.Data.Advice != nil {
|
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
|
}
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
}
|
|
singleCall = func() {
|
|
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
|
randomID, _ := generateHexID(16)
|
|
var nextContentIndex int
|
|
if contentIndex+LengthLimit >= len(content) {
|
|
nextContentIndex = len(content)
|
|
} else {
|
|
nextContentIndex = contentIndex + LengthLimit
|
|
}
|
|
contentPiece := content[contentIndex:nextContentIndex]
|
|
contentIndex = nextContentIndex
|
|
log.Debugf("current content piece: %s", contentPiece)
|
|
checkService := config.getResponseCheckService(consumer)
|
|
params := map[string]string{
|
|
"Format": "JSON",
|
|
"Version": "2022-03-02",
|
|
"SignatureMethod": "Hmac-SHA1",
|
|
"SignatureNonce": randomID,
|
|
"SignatureVersion": "1.0",
|
|
"Action": config.action,
|
|
"AccessKeyId": config.ak,
|
|
"Timestamp": timestamp,
|
|
"Service": checkService,
|
|
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent),
|
|
}
|
|
if config.token != "" {
|
|
params["SecurityToken"] = config.token
|
|
}
|
|
signature := getSign(params, config.sk+"&")
|
|
reqParams := url.Values{}
|
|
for k, v := range params {
|
|
reqParams.Add(k, v)
|
|
}
|
|
reqParams.Add("Signature", signature)
|
|
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
|
|
if err != nil {
|
|
log.Errorf("failed call the safe check service: %v", err)
|
|
proxywasm.ResumeHttpResponse()
|
|
}
|
|
}
|
|
singleCall()
|
|
return types.ActionPause
|
|
}
|
|
|
|
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
|
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
|
|
strChunks := []string{}
|
|
for _, chunk := range chunks {
|
|
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
|
strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String())
|
|
}
|
|
return strings.Join(strChunks, "")
|
|
}
|