[feat] ai-security-guard refactor & support checking multimoadl input (#3075)

This commit is contained in:
rinfx
2025-12-04 16:33:59 +08:00
committed by GitHub
parent 3e24d66079
commit 896bcacf4c
15 changed files with 1932 additions and 1014 deletions

View File

@@ -0,0 +1,585 @@
package config
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
MaxRisk = "max"
HighRisk = "high"
MediumRisk = "medium"
LowRisk = "low"
NoRisk = "none"
S4Sensitive = "s4"
S3Sensitive = "s3"
S2Sensitive = "s2"
S1Sensitive = "s1"
NoSensitive = "s0"
ContentModerationType = "contentModeration"
PromptAttackType = "promptAttack"
SensitiveDataType = "sensitiveData"
MaliciousUrlDataType = "maliciousUrl"
ModelHallucinationDataType = "modelHallucination"
// Default configurations
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
DefaultDenyCode = 200
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
DefaultTimeout = 2000
AliyunUserAgent = "CIPFrom/AIGateway"
LengthLimit = 1800
DefaultRequestCheckService = "llm_query_moderation"
DefaultResponseCheckService = "llm_response_moderation"
DefaultRequestJsonPath = "messages.@reverse.0.content"
DefaultResponseJsonPath = "choices.0.message.content"
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
// Actions
MultiModalGuard = "MultiModalGuard"
MultiModalGuardForBase64 = "MultiModalGuardForBase64"
TextModerationPlus = "TextModerationPlus"
// Services
DefaultMultiModalGuardTextInputCheckService = "query_security_check"
DefaultMultiModalGuardTextOutputCheckService = "response_security_check"
DefaultMultiModalGuardImageInputCheckService = "img_query_security_check"
DefaultTextModerationPlusTextInputCheckService = "llm_query_moderation"
DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
)
// api types
const (
ApiTextGeneration = "text_generation"
ApiImageGeneration = "image_generation"
)
// provider types
const (
ProviderOpenAI = "openai"
ProviderQwen = "qwen"
ProviderComfyUI = "comfyui"
)
type Response struct {
Code int `json:"Code"`
Message string `json:"Message"`
RequestId string `json:"RequestId"`
Data Data `json:"Data"`
}
type Data struct {
RiskLevel string `json:"RiskLevel,omitempty"`
AttackLevel string `json:"AttackLevel,omitempty"`
Result []Result `json:"Result,omitempty"`
Advice []Advice `json:"Advice,omitempty"`
Detail []Detail `json:"Detail,omitempty"`
}
type Result struct {
RiskWords string `json:"RiskWords,omitempty"`
Description string `json:"Description,omitempty"`
Confidence float64 `json:"Confidence,omitempty"`
Label string `json:"Label,omitempty"`
}
type Advice struct {
Answer string `json:"Answer,omitempty"`
HitLabel string `json:"HitLabel,omitempty"`
HitLibName string `json:"HitLibName,omitempty"`
}
type Detail struct {
Suggestion string `json:"Suggestion,omitempty"`
Type string `json:"Type,omitempty"`
Level string `json:"Level,omitempty"`
}
type Matcher struct {
Exact string
Prefix string
Re *regexp.Regexp
}
func (m *Matcher) match(consumer string) bool {
if m.Exact != "" {
return consumer == m.Exact
} else if m.Prefix != "" {
return strings.HasPrefix(consumer, m.Prefix)
} else if m.Re != nil {
return m.Re.MatchString(consumer)
} else {
return false
}
}
type AISecurityConfig struct {
Client wrapper.HttpClient
Host string
AK string
SK string
Token string
Action string
CheckRequest bool
CheckRequestImage bool
RequestCheckService string
RequestImageCheckService string
RequestContentJsonPath string
CheckResponse bool
ResponseCheckService string
ResponseImageCheckService string
ResponseContentJsonPath string
ResponseStreamContentJsonPath string
DenyCode int64
DenyMessage string
ProtocolOriginal bool
RiskLevelBar string
ContentModerationLevelBar string
PromptAttackLevelBar string
SensitiveDataLevelBar string
MaliciousUrlLevelBar string
ModelHallucinationLevelBar string
Timeout uint32
BufferLimit int
Metrics map[string]proxywasm.MetricCounter
ConsumerRequestCheckService []map[string]interface{}
ConsumerResponseCheckService []map[string]interface{}
ConsumerRiskLevel []map[string]interface{}
// text_generation, image_generation, etc.
ApiType string
// openai, qwen, comfyui, etc.
ProviderType string
}
func (config *AISecurityConfig) Parse(json gjson.Result) error {
serviceName := json.Get("serviceName").String()
servicePort := json.Get("servicePort").Int()
serviceHost := json.Get("serviceHost").String()
config.Host = serviceHost
if serviceName == "" || servicePort == 0 || serviceHost == "" {
return errors.New("invalid service config")
}
config.AK = json.Get("accessKey").String()
config.SK = json.Get("secretKey").String()
if config.AK == "" || config.SK == "" {
return errors.New("invalid AK/SK config")
}
config.Token = json.Get("securityToken").String()
// set action
if obj := json.Get("action"); obj.Exists() {
config.Action = json.Get("action").String()
} else {
config.Action = TextModerationPlus
}
// set default values
config.SetDefaultValues()
// set values
if obj := json.Get("riskLevelBar"); obj.Exists() {
config.RiskLevelBar = obj.String()
}
if obj := json.Get("requestCheckService"); obj.Exists() {
config.RequestCheckService = obj.String()
}
if obj := json.Get("requestImageCheckService"); obj.Exists() {
config.RequestImageCheckService = obj.String()
}
if obj := json.Get("responseCheckService"); obj.Exists() {
config.ResponseCheckService = obj.String()
}
if obj := json.Get("responseImageCheckService"); obj.Exists() {
config.ResponseImageCheckService = obj.String()
}
config.CheckRequest = json.Get("checkRequest").Bool()
config.CheckRequestImage = json.Get("checkRequestImage").Bool()
config.CheckResponse = json.Get("checkResponse").Bool()
config.ProtocolOriginal = json.Get("protocol").String() == "original"
config.DenyMessage = json.Get("denyMessage").String()
if obj := json.Get("denyCode"); obj.Exists() {
config.DenyCode = obj.Int()
}
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
config.RequestContentJsonPath = obj.String()
}
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
config.ResponseContentJsonPath = obj.String()
}
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
config.ResponseStreamContentJsonPath = obj.String()
}
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
config.ContentModerationLevelBar = obj.String()
if LevelToInt(config.ContentModerationLevelBar) <= 0 {
return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
}
}
if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
config.PromptAttackLevelBar = obj.String()
if LevelToInt(config.PromptAttackLevelBar) <= 0 {
return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
}
}
if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
config.SensitiveDataLevelBar = obj.String()
if LevelToInt(config.SensitiveDataLevelBar) <= 0 {
return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
}
}
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
config.ModelHallucinationLevelBar = obj.String()
if LevelToInt(config.ModelHallucinationLevelBar) <= 0 {
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
}
}
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
config.MaliciousUrlLevelBar = obj.String()
if LevelToInt(config.MaliciousUrlLevelBar) <= 0 {
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
}
}
if obj := json.Get("timeout"); obj.Exists() {
config.Timeout = uint32(obj.Int())
}
if obj := json.Get("bufferLimit"); obj.Exists() {
config.BufferLimit = int(obj.Int())
}
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
for _, item := range json.Get("consumerRequestCheckService").Array() {
m := make(map[string]interface{})
for k, v := range item.Map() {
m[k] = v.Value()
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
continue
}
switch fmt.Sprint(matchType) {
case "exact":
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
case "prefix":
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
case "regexp":
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
}
config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, m)
}
}
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
for _, item := range json.Get("consumerResponseCheckService").Array() {
m := make(map[string]interface{})
for k, v := range item.Map() {
m[k] = v.Value()
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
continue
}
switch fmt.Sprint(matchType) {
case "exact":
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
case "prefix":
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
case "regexp":
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
}
config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, m)
}
}
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
for _, item := range json.Get("consumerRiskLevel").Array() {
m := make(map[string]interface{})
for k, v := range item.Map() {
m[k] = v.Value()
}
consumerName, ok1 := m["name"]
matchType, ok2 := m["matchType"]
if !ok1 || !ok2 {
continue
}
switch fmt.Sprint(matchType) {
case "exact":
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
case "prefix":
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
case "regexp":
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
}
config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, m)
}
}
if obj := json.Get("apiType"); obj.Exists() {
config.ApiType = obj.String()
}
if obj := json.Get("providerType"); obj.Exists() {
config.ProviderType = obj.String()
}
config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
Host: serviceHost,
})
config.Metrics = make(map[string]proxywasm.MetricCounter)
return nil
}
func (config *AISecurityConfig) SetDefaultValues() {
switch config.Action {
case TextModerationPlus:
config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService
config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService
case MultiModalGuard:
config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService
config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService
config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService
}
config.RiskLevelBar = HighRisk
config.DenyCode = DefaultDenyCode
config.RequestContentJsonPath = DefaultRequestJsonPath
config.ResponseContentJsonPath = DefaultResponseJsonPath
config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
config.ContentModerationLevelBar = MaxRisk
config.PromptAttackLevelBar = MaxRisk
config.SensitiveDataLevelBar = S4Sensitive
config.ModelHallucinationLevelBar = MaxRisk
config.MaliciousUrlLevelBar = MaxRisk
config.Timeout = DefaultTimeout
config.BufferLimit = 1000
config.ApiType = ApiTextGeneration
config.ProviderType = ProviderOpenAI
}
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
counter, ok := config.Metrics[metricName]
if !ok {
counter = proxywasm.DefineCounterMetric(metricName)
config.Metrics[metricName] = counter
}
counter.Increment(inc)
}
func (config *AISecurityConfig) GetRequestCheckService(consumer string) string {
result := config.RequestCheckService
for _, obj := range config.ConsumerRequestCheckService {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if requestCheckService, ok := obj["requestCheckService"]; ok {
result, _ = requestCheckService.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string {
result := config.RequestImageCheckService
for _, obj := range config.ConsumerRequestCheckService {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if requestCheckService, ok := obj["requestImageCheckService"]; ok {
result, _ = requestCheckService.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetResponseCheckService(consumer string) string {
result := config.ResponseCheckService
for _, obj := range config.ConsumerResponseCheckService {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if responseCheckService, ok := obj["responseCheckService"]; ok {
result, _ = responseCheckService.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string {
result := config.ResponseImageCheckService
for _, obj := range config.ConsumerResponseCheckService {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if responseCheckService, ok := obj["responseImageCheckService"]; ok {
result, _ = responseCheckService.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string {
result := config.RiskLevelBar
for _, obj := range config.ConsumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
result, _ = riskLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string {
result := config.ContentModerationLevelBar
for _, obj := range config.ConsumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
result, _ = contentModerationLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string {
result := config.PromptAttackLevelBar
for _, obj := range config.ConsumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
result, _ = promptAttackLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string {
result := config.SensitiveDataLevelBar
for _, obj := range config.ConsumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
result, _ = sensitiveDataLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string {
result := config.MaliciousUrlLevelBar
for _, obj := range config.ConsumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
result, _ = maliciousUrlLevelBar.(string)
}
break
}
}
}
return result
}
func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string {
result := config.ModelHallucinationLevelBar
for _, obj := range config.ConsumerRiskLevel {
if matcher, ok := obj["matcher"].(Matcher); ok {
if matcher.match(consumer) {
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
result, _ = modelHallucinationLevelBar.(string)
}
break
}
}
}
return result
}
func LevelToInt(riskLevel string) int {
// First check against our defined constants
switch strings.ToLower(riskLevel) {
case MaxRisk, S4Sensitive:
return 4
case HighRisk, S3Sensitive:
return 3
case MediumRisk, S2Sensitive:
return 2
case LowRisk, S1Sensitive:
return 1
case NoRisk, NoSensitive:
return 0
default:
return -1
}
}
func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
if action == MultiModalGuard || action == MultiModalGuardForBase64 {
// Check top-level risk levels for MultiModalGuard
if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
return false
}
// Also check AttackLevel for prompt attack detection
if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
return false
}
// Check detailed results for backward compatibility
for _, detail := range data.Detail {
switch detail.Type {
case ContentModerationType:
if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
return false
}
case PromptAttackType:
if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
return false
}
case SensitiveDataType:
if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) {
return false
}
case MaliciousUrlDataType:
if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) {
return false
}
case ModelHallucinationDataType:
if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) {
return false
}
}
}
return true
} else {
return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
}
}

View File

@@ -5,8 +5,8 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
@@ -20,5 +20,6 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -4,8 +4,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
@@ -24,6 +28,8 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,249 @@
package common
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"sort"
"golang.org/x/exp/maps"
"fmt"
"net/url"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/google/uuid"
)
const (
ALGORITHM = "ACS3-HMAC-SHA256"
)
type Request struct {
httpMethod string
canonicalUri string
host string
xAcsAction string
xAcsVersion string
headers map[string]string
body []byte
queryParam map[string]interface{}
}
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
req := &Request{
httpMethod: httpMethod,
canonicalUri: canonicalUri,
host: host,
xAcsAction: xAcsAction,
xAcsVersion: xAcsVersion,
headers: make(map[string]string),
queryParam: make(map[string]interface{}),
}
req.headers["host"] = host
req.headers["x-acs-action"] = xAcsAction
req.headers["x-acs-version"] = xAcsVersion
req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339)
req.headers["x-acs-signature-nonce"] = uuid.New().String()
return req
}
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
newQueryParams := make(map[string]interface{})
processObject(newQueryParams, "", req.queryParam)
req.queryParam = newQueryParams
canonicalQueryString := ""
keys := maps.Keys(req.queryParam)
sort.Strings(keys)
for _, k := range keys {
v := req.queryParam[k]
canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&"
}
canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&")
var bodyContent []byte
if req.body == nil {
bodyContent = []byte("")
} else {
bodyContent = req.body
}
hashedRequestPayload := sha256Hex(bodyContent)
req.headers["x-acs-content-sha256"] = hashedRequestPayload
if SecurityToken != "" {
req.headers["x-acs-security-token"] = SecurityToken
}
canonicalHeaders := ""
signedHeaders := ""
HeadersKeys := maps.Keys(req.headers)
sort.Strings(HeadersKeys)
for _, k := range HeadersKeys {
lowerKey := strings.ToLower(k)
if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
signedHeaders += lowerKey + ";"
}
}
signedHeaders = strings.TrimSuffix(signedHeaders, ";")
canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest))
stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest
byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
if err != nil {
fmt.Println(err)
panic(err)
}
signature := strings.ToLower(hex.EncodeToString(byteData))
authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
req.headers["Authorization"] = authorization
}
func hmac256(key []byte, toSignString string) ([]byte, error) {
h := hmac.New(sha256.New, key)
_, err := h.Write([]byte(toSignString))
if err != nil {
return nil, err
}
return h.Sum(nil), nil
}
func sha256Hex(byteArray []byte) string {
hash := sha256.New()
_, _ = hash.Write(byteArray)
hexString := hex.EncodeToString(hash.Sum(nil))
return hexString
}
func percentCode(str string) string {
str = strings.ReplaceAll(str, "+", "%20")
str = strings.ReplaceAll(str, "*", "%2A")
str = strings.ReplaceAll(str, "%7E", "~")
return str
}
func formDataToString(formData map[string]interface{}) *string {
tmp := make(map[string]interface{})
processObject(tmp, "", formData)
res := ""
urlEncoder := url.Values{}
for key, value := range tmp {
v := fmt.Sprintf("%v", value)
urlEncoder.Add(key, v)
}
res = urlEncoder.Encode()
return &res
}
// processObject 递归处理对象将复杂对象如Map和List展开为平面的键值对
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
if value == nil {
return
}
switch v := value.(type) {
case []interface{}:
for i, item := range v {
processObject(mapResult, fmt.Sprintf("%s.%d", key, i+1), item)
}
case map[string]interface{}:
for subKey, subValue := range v {
processObject(mapResult, fmt.Sprintf("%s.%s", key, subKey), subValue)
}
default:
if strings.HasPrefix(key, ".") {
key = key[1:]
}
if b, ok := v.([]byte); ok {
mapResult[key] = string(b)
} else {
mapResult[key] = fmt.Sprintf("%v", v)
}
}
}
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
httpMethod := "POST"
canonicalUri := "/"
xAcsVersion := "2022-03-02"
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
req.queryParam["Service"] = checkService
body := make(map[string]interface{})
serviceParameters := make(map[string]interface{})
serviceParameters["content"] = text
serviceParameters["sessionId"] = sessionID
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
serviceParametersJSON, _ := json.Marshal(serviceParameters)
body["ServiceParameters"] = serviceParametersJSON
str := formDataToString(body)
req.body = []byte(*str)
req.headers["content-type"] = "application/x-www-form-urlencoded"
req.headers["User-Agent"] = cfg.AliyunUserAgent
getAuthorization(req, config.AK, config.SK, config.Token)
q := url.Values{}
keys := maps.Keys(req.queryParam)
sort.Strings(keys)
for _, k := range keys {
v := req.queryParam[k]
q.Set(k, fmt.Sprintf("%v", v))
}
for k, v := range req.headers {
if k != "host" {
headers = append(headers, [2]string{k, v})
}
}
return "?" + q.Encode(), headers, req.body
}
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
httpMethod := "POST"
canonicalUri := "/"
xAcsVersion := "2022-03-02"
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
req.queryParam["Service"] = checkService
body := make(map[string]interface{})
serviceParameters := make(map[string]interface{})
if imgUrl != "" {
serviceParameters["imageUrls"] = []string{imgUrl}
}
serviceParametersJSON, _ := json.Marshal(serviceParameters)
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
body["ServiceParameters"] = serviceParametersJSON
if imgBase64 != "" {
body["ImageBase64Str"] = imgBase64
}
str := formDataToString(body)
req.body = []byte(*str)
req.headers["content-type"] = "application/x-www-form-urlencoded"
req.headers["User-Agent"] = cfg.AliyunUserAgent
getAuthorization(req, config.AK, config.SK, config.Token)
q := url.Values{}
keys := maps.Keys(req.queryParam)
sort.Strings(keys)
for _, k := range keys {
v := req.queryParam[k]
q.Set(k, fmt.Sprintf("%v", v))
}
for k, v := range req.headers {
// host will be added by envoy automatically
if k != "host" {
headers = append(headers, [2]string{k, v})
}
}
return "?" + q.Encode(), headers, req.body
}

View File

@@ -0,0 +1,249 @@
package text
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
ctx.SetContext("end_of_stream_received", false)
ctx.SetContext("during_call", false)
ctx.SetContext("risk_detected", false)
sessionID, _ := utils.GenerateHexID(20)
ctx.SetContext("sessionID", sessionID)
if strings.Contains(contentType, "text/event-stream") {
ctx.NeedPauseStreamingResponse()
return types.ActionContinue
} else {
ctx.BufferResponseBody()
return types.HeaderStopIteration
}
}
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
consumer, _ := ctx.GetContext("consumer").(string)
var sessionID string
if ctx.GetContext("sessionID") == nil {
sessionID, _ = utils.GenerateHexID(20)
ctx.SetContext("sessionID", sessionID)
} else {
sessionID, _ = ctx.GetContext("sessionID").(string)
}
var bufferQueue [][]byte
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if ctx.GetContext("end_of_stream_received").(bool) {
proxywasm.ResumeHttpResponse()
}
ctx.SetContext("during_call", false)
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
if ctx.GetContext("end_of_stream_received").(bool) {
proxywasm.ResumeHttpResponse()
}
ctx.SetContext("during_call", false)
return
}
if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
denyMessage := cfg.DefaultDenyMessage
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = "\n" + response.Data.Advice[0].Answer
} else if config.DenyMessage != "" {
denyMessage = config.DenyMessage
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
return
}
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
bufferQueue = [][]byte{}
if !endStream {
ctx.SetContext("during_call", false)
singleCall()
}
}
singleCall = func() {
if ctx.GetContext("during_call").(bool) {
return
}
if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
var buffer string
for ctx.BufferQueueSize() > 0 {
front := ctx.PopBuffer()
bufferQueue = append(bufferQueue, front)
msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
buffer += msg
if len([]rune(buffer)) >= config.BufferLimit {
break
}
}
// if streaming body has reasoning_content, buffer maybe empty
log.Debugf("current content piece: %s", buffer)
if len(buffer) == 0 {
return
}
ctx.SetContext("during_call", true)
log.Debugf("current content piece: %s", buffer)
checkService := config.GetResponseCheckService(consumer)
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
if ctx.GetContext("end_of_stream_received").(bool) {
proxywasm.ResumeHttpResponse()
}
}
}
}
if !ctx.GetContext("risk_detected").(bool) {
unifiedChunk := wrapper.UnifySSEChunk(data)
hasTrailingSeparator := bytes.HasSuffix(unifiedChunk, []byte("\n\n"))
trimmedChunk := bytes.TrimSpace(unifiedChunk)
chunks := bytes.Split(trimmedChunk, []byte("\n\n"))
// Filter out empty chunks
nonEmptyChunks := make([][]byte, 0, len(chunks))
for _, chunk := range chunks {
if len(chunk) > 0 {
nonEmptyChunks = append(nonEmptyChunks, chunk)
}
}
// Restore separators
for i := range len(nonEmptyChunks) - 1 {
nonEmptyChunks[i] = append(nonEmptyChunks[i], []byte("\n\n")...)
}
if hasTrailingSeparator && len(nonEmptyChunks) > 0 {
nonEmptyChunks[len(nonEmptyChunks)-1] = append(nonEmptyChunks[len(nonEmptyChunks)-1], []byte("\n\n")...)
}
for _, chunk := range nonEmptyChunks {
ctx.PushBuffer(chunk)
}
// for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
// ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
// }
ctx.SetContext("end_of_stream_received", endOfStream)
if !ctx.GetContext("during_call").(bool) {
singleCall()
}
} else if endOfStream {
proxywasm.ResumeHttpResponse()
}
return []byte{}
}
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking response body...")
startTime := time.Now().UnixMilli()
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
isStreamingResponse := strings.Contains(contentType, "event-stream")
var content string
if isStreamingResponse {
content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
} else {
content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
}
log.Debugf("Raw response content is: %s", content)
if len(content) == 0 {
log.Info("response content is empty. skip")
return types.ActionContinue
}
contentIndex := 0
sessionID, _ := utils.GenerateHexID(20)
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpResponse()
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
proxywasm.ResumeHttpResponse()
return
}
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
config.IncrementCounter("ai_sec_response_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
} else {
nextContentIndex = contentIndex + cfg.LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
checkService := config.GetResponseCheckService(consumer)
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
}
}
singleCall()
return types.ActionPause
}

View File

@@ -0,0 +1,67 @@
package multi_modal_guard
import (
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
return types.ActionContinue
}
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
return text.HandleTextGenerationRequestBody(ctx, config, body)
}
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationResponseHeader(ctx, config)
case cfg.ApiImageGeneration:
switch config.ProviderType {
case cfg.ProviderOpenAI, cfg.ProviderQwen:
return image.HandleImageGenerationResponseHeader(ctx, config)
default:
log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType)
return types.ActionContinue
}
default:
log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
return types.ActionContinue
}
}
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
default:
log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
return data
}
}
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
case cfg.ApiImageGeneration:
switch config.ProviderType {
case cfg.ProviderOpenAI:
return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body)
case cfg.ProviderQwen:
return image.HandleQwenImageGenerationResponseBody(ctx, config, body)
default:
log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType)
return types.ActionContinue
}
default:
log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
return types.ActionContinue
}
}

View File

@@ -0,0 +1,22 @@
package image
import (
"strings"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
ctx.SetContext("risk_detected", false)
if strings.Contains(contentType, "text/event-stream") {
ctx.DontReadResponseBody()
return types.ActionContinue
} else {
ctx.BufferResponseBody()
return types.HeaderStopIteration
}
}

View File

@@ -0,0 +1,111 @@
package image
import (
"encoding/json"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type ImageItemForOpenAI struct {
Content string
Type string // URL or BASE64
}
func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
result := []ImageItemForOpenAI{}
for _, part := range gjson.GetBytes(body, "data").Array() {
if url := part.Get("url").String(); url != "" {
result = append(result, ImageItemForOpenAI{
Content: url,
Type: "URL",
})
}
if b64 := part.Get("b64_json").String(); b64 != "" {
result = append(result, ImageItemForOpenAI{
Content: b64,
Type: "BASE64",
})
}
}
return result
}
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking response body...")
checkImageService := config.GetResponseImageCheckService(consumer)
startTime := time.Now().UnixMilli()
imgResults := getOpenAIImageResults(body)
if len(imgResults) == 0 {
return types.ActionContinue
}
imageIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if imageIndex < len(imgResults) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
if imageIndex < len(imgResults) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(imgResults) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
}
return
}
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
img := imgResults[imageIndex]
imgUrl := ""
imgBase64 := ""
if img.Type == "BASE64" {
imgBase64 = img.Content
} else {
imgUrl = img.Content
}
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
}
}
singleCall()
return types.ActionPause
}

View File

@@ -0,0 +1,134 @@
package image
import (
"encoding/json"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
func getQwenImageUrls(body []byte) []string {
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
result := []string{}
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
// 虚拟模特/图像背景生成/人物写真FaceChain/文生图StableDiffusion/文生图FLUX/文字纹理生成API
for _, part := range gjson.GetBytes(body, "output.results").Array() {
if url := part.Get("url").String(); url != "" {
result = append(result, url)
}
}
// 图像编辑
for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() {
if url := part.Get("image").String(); url != "" {
result = append(result, url)
}
}
// 图像翻译/AI试衣OutfitAnyone
if url := gjson.GetBytes(body, "output.image_url").String(); url != "" {
result = append(result, url)
}
// 图像画面扩展/(part of)人物实例分割/图像擦除补全
if url := gjson.GetBytes(body, "output.output_image_url").String(); url != "" {
result = append(result, url)
}
// 鞋靴模特
if url := gjson.GetBytes(body, "output.result_url").String(); url != "" {
result = append(result, url)
}
// 创意海报生成
for _, part := range gjson.GetBytes(body, "output.render_urls").Array() {
if url := part.String(); url != "" {
result = append(result, url)
}
}
for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() {
if url := part.String(); url != "" {
result = append(result, url)
}
}
// 人物实例分割
if url := gjson.GetBytes(body, "output.output_vis_image_url").String(); url != "" {
result = append(result, url)
}
// 文字变形API
for _, part := range gjson.GetBytes(body, "output.results").Array() {
if url := part.Get("png_url").String(); url != "" {
result = append(result, url)
}
if url := part.Get("svg_url").String(); url != "" {
result = append(result, url)
}
}
return result
}
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking response body...")
checkImageService := config.GetResponseImageCheckService(consumer)
startTime := time.Now().UnixMilli()
imgUrls := getQwenImageUrls(body)
if len(imgUrls) == 0 {
return types.ActionContinue
}
imageIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if imageIndex < len(imgUrls) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
if imageIndex < len(imgUrls) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(imgUrls) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
}
return
}
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
imgUrl := imgUrls[imageIndex]
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
}
}
singleCall()
return types.ActionPause
}

View File

@@ -0,0 +1,231 @@
package text
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type ImageItemForOpenAI struct {
Content string
Type string // URL or BASE64
}
func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI) {
images = []ImageItemForOpenAI{}
if json.IsArray() {
for _, item := range json.Array() {
switch item.Get("type").String() {
case "text":
text += item.Get("text").String()
case "image_url":
imgContent := item.Get("image_url.url").String()
if strings.HasPrefix(imgContent, "data:image") {
images = append(images, ImageItemForOpenAI{
Content: imgContent,
Type: "BASE64",
})
} else {
images = append(images, ImageItemForOpenAI{
Content: imgContent,
Type: "URL",
})
}
}
}
} else {
text = json.String()
}
return text, images
}
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
checkService := config.GetRequestCheckService(consumer)
checkImageService := config.GetRequestImageCheckService(consumer)
startTime := time.Now().UnixMilli()
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
content, images := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath))
log.Debugf("Raw request content is: %s", content)
if len(content) == 0 && len(images) == 0 {
log.Info("request content is empty. skip")
return types.ActionContinue
}
contentIndex := 0
imageIndex := 0
sessionID, _ := utils.GenerateHexID(20)
var singleCall func()
var singleCallForImage func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpRequest()
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
proxywasm.ResumeHttpRequest()
return
}
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
} else {
singleCall()
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
} else {
nextContentIndex = contentIndex + cfg.LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
}
}
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if imageIndex < len(images) {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
if imageIndex < len(images) {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
return
}
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(images) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpRequest()
} else {
singleCallForImage()
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCallForImage = func() {
img := images[imageIndex]
imgUrl := ""
imgBase64 := ""
if img.Type == "BASE64" {
imgBase64 = img.Content
} else {
imgUrl = img.Content
}
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
}
}
if len(content) > 0 {
singleCall()
} else {
singleCallForImage()
}
return types.ActionPause
}

View File

@@ -0,0 +1,48 @@
package text_moderation_plus
import (
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
return types.ActionContinue
}
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
return text.HandleTextGenerationRequestBody(ctx, config, body)
}
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationResponseHeader(ctx, config)
default:
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
return types.ActionContinue
}
}
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
default:
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
return data
}
}
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
default:
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
return types.ActionContinue
}
}

View File

@@ -0,0 +1,104 @@
package text
import (
"encoding/json"
"fmt"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
log.Debugf("Raw request content is: %s", content)
if len(content) == 0 {
log.Info("request content is empty. skip")
return types.ActionContinue
}
contentIndex := 0
sessionID, _ := utils.GenerateHexID(20)
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpRequest()
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at request phase")
proxywasm.ResumeHttpRequest()
return
}
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpRequest()
} else {
singleCall()
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
} else {
nextContentIndex = contentIndex + cfg.LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
checkService := config.GetRequestCheckService(consumer)
path, headers, body := common.GenerateRequestForText(config, cfg.TextModerationPlus, checkService, contentPiece, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
}
}
singleCall()
return types.ActionPause
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,8 @@ import (
"encoding/json"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
@@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.ak)
require.Equal(t, "test-sk", securityConfig.sk)
require.Equal(t, true, securityConfig.checkRequest)
require.Equal(t, true, securityConfig.checkResponse)
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.timeout)
require.Equal(t, 1000, securityConfig.bufferLimit)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.AK)
require.Equal(t, "test-sk", securityConfig.SK)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, true, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.Timeout)
require.Equal(t, 1000, securityConfig.BufferLimit)
})
// 测试仅检查请求的配置
@@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, true, securityConfig.checkRequest)
require.Equal(t, false, securityConfig.checkResponse)
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, false, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
})
// 测试缺少必需字段的配置
@@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
@@ -385,62 +387,27 @@ func TestOnHttpResponseHeaders(t *testing.T) {
func TestRiskLevelFunctions(t *testing.T) {
// 测试风险等级转换函数
t.Run("risk level conversion", func(t *testing.T) {
require.Equal(t, 4, levelToInt(MaxRisk))
require.Equal(t, 3, levelToInt(HighRisk))
require.Equal(t, 2, levelToInt(MediumRisk))
require.Equal(t, 1, levelToInt(LowRisk))
require.Equal(t, 0, levelToInt(NoRisk))
require.Equal(t, -1, levelToInt("invalid"))
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
require.Equal(t, -1, cfg.LevelToInt("invalid"))
})
// 测试风险等级比较
t.Run("risk level comparison", func(t *testing.T) {
require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk))
require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk))
require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk))
require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk))
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
})
}
func TestUtilityFunctions(t *testing.T) {
// 测试URL编码函数
t.Run("url encoding", func(t *testing.T) {
original := "test+string:with=special&chars@$"
encoded := urlEncoding(original)
require.NotEqual(t, original, encoded)
require.Contains(t, encoded, "%2B") // + 应该被编码
require.Contains(t, encoded, "%3A") // : 应该被编码
require.Contains(t, encoded, "%3D") // = 应该被编码
require.Contains(t, encoded, "%26") // & 应该被编码
})
// 测试HMAC-SHA1签名函数
t.Run("hmac sha1", func(t *testing.T) {
message := "test message"
secret := "test secret"
signature := hmacSha1(message, secret)
require.NotEmpty(t, signature)
require.NotEqual(t, message, signature)
})
// 测试签名生成函数
t.Run("signature generation", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
params := map[string]string{
"key1": "value1",
"key2": "value2",
}
secret := "test-secret"
signature := getSign(params, secret)
require.NotEmpty(t, signature)
})
// 测试十六进制ID生成函数
t.Run("hex id generation", func(t *testing.T) {
id, err := generateHexID(16)
id, err := utils.GenerateHexID(16)
require.NoError(t, err)
require.Len(t, id, 16)
require.Regexp(t, "^[0-9a-f]+$", id)
@@ -448,7 +415,7 @@ func TestUtilityFunctions(t *testing.T) {
// 测试随机ID生成函数
t.Run("random id generation", func(t *testing.T) {
id := generateRandomID()
id := utils.GenerateRandomChatID()
require.NotEmpty(t, id)
require.Contains(t, id, "chatcmpl-")
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars

View File

@@ -0,0 +1,43 @@
package utils
import (
"bytes"
"crypto/rand"
"encoding/hex"
mrand "math/rand"
"strings"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
func GenerateHexID(length int) (string, error) {
bytes := make([]byte, length/2)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func GenerateRandomChatID() string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, 29)
for i := range b {
b[i] = charset[mrand.Intn(len(charset))]
}
return "chatcmpl-" + string(b)
}
func ExtractMessageFromStreamingBody(data []byte, jsonPath string) string {
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String())
}
return strings.Join(strChunks, "")
}
func GetConsumer(ctx wrapper.HttpContext) string {
return ctx.GetStringContext("consumer", "")
}