mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 15:10:54 +08:00
[feat] ai-security-guard refactor & support checking multimoadl input (#3075)
This commit is contained in:
585
plugins/wasm-go/extensions/ai-security-guard/config/config.go
Normal file
585
plugins/wasm-go/extensions/ai-security-guard/config/config.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxRisk = "max"
|
||||
HighRisk = "high"
|
||||
MediumRisk = "medium"
|
||||
LowRisk = "low"
|
||||
NoRisk = "none"
|
||||
|
||||
S4Sensitive = "s4"
|
||||
S3Sensitive = "s3"
|
||||
S2Sensitive = "s2"
|
||||
S1Sensitive = "s1"
|
||||
NoSensitive = "s0"
|
||||
|
||||
ContentModerationType = "contentModeration"
|
||||
PromptAttackType = "promptAttack"
|
||||
SensitiveDataType = "sensitiveData"
|
||||
MaliciousUrlDataType = "maliciousUrl"
|
||||
ModelHallucinationDataType = "modelHallucination"
|
||||
|
||||
// Default configurations
|
||||
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
||||
|
||||
DefaultDenyCode = 200
|
||||
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
|
||||
DefaultTimeout = 2000
|
||||
|
||||
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||
LengthLimit = 1800
|
||||
|
||||
DefaultRequestCheckService = "llm_query_moderation"
|
||||
DefaultResponseCheckService = "llm_response_moderation"
|
||||
DefaultRequestJsonPath = "messages.@reverse.0.content"
|
||||
DefaultResponseJsonPath = "choices.0.message.content"
|
||||
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
|
||||
|
||||
// Actions
|
||||
MultiModalGuard = "MultiModalGuard"
|
||||
MultiModalGuardForBase64 = "MultiModalGuardForBase64"
|
||||
TextModerationPlus = "TextModerationPlus"
|
||||
|
||||
// Services
|
||||
DefaultMultiModalGuardTextInputCheckService = "query_security_check"
|
||||
DefaultMultiModalGuardTextOutputCheckService = "response_security_check"
|
||||
DefaultMultiModalGuardImageInputCheckService = "img_query_security_check"
|
||||
|
||||
DefaultTextModerationPlusTextInputCheckService = "llm_query_moderation"
|
||||
DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
|
||||
)
|
||||
|
||||
// api types
|
||||
|
||||
const (
|
||||
ApiTextGeneration = "text_generation"
|
||||
ApiImageGeneration = "image_generation"
|
||||
)
|
||||
|
||||
// provider types
|
||||
const (
|
||||
ProviderOpenAI = "openai"
|
||||
ProviderQwen = "qwen"
|
||||
ProviderComfyUI = "comfyui"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"Code"`
|
||||
Message string `json:"Message"`
|
||||
RequestId string `json:"RequestId"`
|
||||
Data Data `json:"Data"`
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
RiskLevel string `json:"RiskLevel,omitempty"`
|
||||
AttackLevel string `json:"AttackLevel,omitempty"`
|
||||
Result []Result `json:"Result,omitempty"`
|
||||
Advice []Advice `json:"Advice,omitempty"`
|
||||
Detail []Detail `json:"Detail,omitempty"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
RiskWords string `json:"RiskWords,omitempty"`
|
||||
Description string `json:"Description,omitempty"`
|
||||
Confidence float64 `json:"Confidence,omitempty"`
|
||||
Label string `json:"Label,omitempty"`
|
||||
}
|
||||
|
||||
type Advice struct {
|
||||
Answer string `json:"Answer,omitempty"`
|
||||
HitLabel string `json:"HitLabel,omitempty"`
|
||||
HitLibName string `json:"HitLibName,omitempty"`
|
||||
}
|
||||
|
||||
type Detail struct {
|
||||
Suggestion string `json:"Suggestion,omitempty"`
|
||||
Type string `json:"Type,omitempty"`
|
||||
Level string `json:"Level,omitempty"`
|
||||
}
|
||||
|
||||
type Matcher struct {
|
||||
Exact string
|
||||
Prefix string
|
||||
Re *regexp.Regexp
|
||||
}
|
||||
|
||||
func (m *Matcher) match(consumer string) bool {
|
||||
if m.Exact != "" {
|
||||
return consumer == m.Exact
|
||||
} else if m.Prefix != "" {
|
||||
return strings.HasPrefix(consumer, m.Prefix)
|
||||
} else if m.Re != nil {
|
||||
return m.Re.MatchString(consumer)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type AISecurityConfig struct {
|
||||
Client wrapper.HttpClient
|
||||
Host string
|
||||
AK string
|
||||
SK string
|
||||
Token string
|
||||
Action string
|
||||
CheckRequest bool
|
||||
CheckRequestImage bool
|
||||
RequestCheckService string
|
||||
RequestImageCheckService string
|
||||
RequestContentJsonPath string
|
||||
CheckResponse bool
|
||||
ResponseCheckService string
|
||||
ResponseImageCheckService string
|
||||
ResponseContentJsonPath string
|
||||
ResponseStreamContentJsonPath string
|
||||
DenyCode int64
|
||||
DenyMessage string
|
||||
ProtocolOriginal bool
|
||||
RiskLevelBar string
|
||||
ContentModerationLevelBar string
|
||||
PromptAttackLevelBar string
|
||||
SensitiveDataLevelBar string
|
||||
MaliciousUrlLevelBar string
|
||||
ModelHallucinationLevelBar string
|
||||
Timeout uint32
|
||||
BufferLimit int
|
||||
Metrics map[string]proxywasm.MetricCounter
|
||||
ConsumerRequestCheckService []map[string]interface{}
|
||||
ConsumerResponseCheckService []map[string]interface{}
|
||||
ConsumerRiskLevel []map[string]interface{}
|
||||
// text_generation, image_generation, etc.
|
||||
ApiType string
|
||||
// openai, qwen, comfyui, etc.
|
||||
ProviderType string
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) Parse(json gjson.Result) error {
|
||||
serviceName := json.Get("serviceName").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
serviceHost := json.Get("serviceHost").String()
|
||||
config.Host = serviceHost
|
||||
if serviceName == "" || servicePort == 0 || serviceHost == "" {
|
||||
return errors.New("invalid service config")
|
||||
}
|
||||
config.AK = json.Get("accessKey").String()
|
||||
config.SK = json.Get("secretKey").String()
|
||||
if config.AK == "" || config.SK == "" {
|
||||
return errors.New("invalid AK/SK config")
|
||||
}
|
||||
config.Token = json.Get("securityToken").String()
|
||||
// set action
|
||||
if obj := json.Get("action"); obj.Exists() {
|
||||
config.Action = json.Get("action").String()
|
||||
} else {
|
||||
config.Action = TextModerationPlus
|
||||
}
|
||||
// set default values
|
||||
config.SetDefaultValues()
|
||||
// set values
|
||||
if obj := json.Get("riskLevelBar"); obj.Exists() {
|
||||
config.RiskLevelBar = obj.String()
|
||||
}
|
||||
if obj := json.Get("requestCheckService"); obj.Exists() {
|
||||
config.RequestCheckService = obj.String()
|
||||
}
|
||||
if obj := json.Get("requestImageCheckService"); obj.Exists() {
|
||||
config.RequestImageCheckService = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseCheckService"); obj.Exists() {
|
||||
config.ResponseCheckService = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseImageCheckService"); obj.Exists() {
|
||||
config.ResponseImageCheckService = obj.String()
|
||||
}
|
||||
config.CheckRequest = json.Get("checkRequest").Bool()
|
||||
config.CheckRequestImage = json.Get("checkRequestImage").Bool()
|
||||
config.CheckResponse = json.Get("checkResponse").Bool()
|
||||
config.ProtocolOriginal = json.Get("protocol").String() == "original"
|
||||
config.DenyMessage = json.Get("denyMessage").String()
|
||||
if obj := json.Get("denyCode"); obj.Exists() {
|
||||
config.DenyCode = obj.Int()
|
||||
}
|
||||
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
|
||||
config.RequestContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
|
||||
config.ResponseContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
|
||||
config.ResponseStreamContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
|
||||
config.ContentModerationLevelBar = obj.String()
|
||||
if LevelToInt(config.ContentModerationLevelBar) <= 0 {
|
||||
return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
|
||||
config.PromptAttackLevelBar = obj.String()
|
||||
if LevelToInt(config.PromptAttackLevelBar) <= 0 {
|
||||
return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
|
||||
config.SensitiveDataLevelBar = obj.String()
|
||||
if LevelToInt(config.SensitiveDataLevelBar) <= 0 {
|
||||
return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
|
||||
config.ModelHallucinationLevelBar = obj.String()
|
||||
if LevelToInt(config.ModelHallucinationLevelBar) <= 0 {
|
||||
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
|
||||
config.MaliciousUrlLevelBar = obj.String()
|
||||
if LevelToInt(config.MaliciousUrlLevelBar) <= 0 {
|
||||
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("timeout"); obj.Exists() {
|
||||
config.Timeout = uint32(obj.Int())
|
||||
}
|
||||
if obj := json.Get("bufferLimit"); obj.Exists() {
|
||||
config.BufferLimit = int(obj.Int())
|
||||
}
|
||||
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerRequestCheckService").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerResponseCheckService").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerRiskLevel").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("apiType"); obj.Exists() {
|
||||
config.ApiType = obj.String()
|
||||
}
|
||||
if obj := json.Get("providerType"); obj.Exists() {
|
||||
config.ProviderType = obj.String()
|
||||
}
|
||||
config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: servicePort,
|
||||
Host: serviceHost,
|
||||
})
|
||||
config.Metrics = make(map[string]proxywasm.MetricCounter)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) SetDefaultValues() {
|
||||
switch config.Action {
|
||||
case TextModerationPlus:
|
||||
config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService
|
||||
config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService
|
||||
case MultiModalGuard:
|
||||
config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService
|
||||
config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService
|
||||
config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService
|
||||
}
|
||||
config.RiskLevelBar = HighRisk
|
||||
config.DenyCode = DefaultDenyCode
|
||||
config.RequestContentJsonPath = DefaultRequestJsonPath
|
||||
config.ResponseContentJsonPath = DefaultResponseJsonPath
|
||||
config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
||||
config.ContentModerationLevelBar = MaxRisk
|
||||
config.PromptAttackLevelBar = MaxRisk
|
||||
config.SensitiveDataLevelBar = S4Sensitive
|
||||
config.ModelHallucinationLevelBar = MaxRisk
|
||||
config.MaliciousUrlLevelBar = MaxRisk
|
||||
config.Timeout = DefaultTimeout
|
||||
config.BufferLimit = 1000
|
||||
config.ApiType = ApiTextGeneration
|
||||
config.ProviderType = ProviderOpenAI
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
|
||||
counter, ok := config.Metrics[metricName]
|
||||
if !ok {
|
||||
counter = proxywasm.DefineCounterMetric(metricName)
|
||||
config.Metrics[metricName] = counter
|
||||
}
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetRequestCheckService(consumer string) string {
|
||||
result := config.RequestCheckService
|
||||
for _, obj := range config.ConsumerRequestCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if requestCheckService, ok := obj["requestCheckService"]; ok {
|
||||
result, _ = requestCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string {
|
||||
result := config.RequestImageCheckService
|
||||
for _, obj := range config.ConsumerRequestCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if requestCheckService, ok := obj["requestImageCheckService"]; ok {
|
||||
result, _ = requestCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetResponseCheckService(consumer string) string {
|
||||
result := config.ResponseCheckService
|
||||
for _, obj := range config.ConsumerResponseCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if responseCheckService, ok := obj["responseCheckService"]; ok {
|
||||
result, _ = responseCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string {
|
||||
result := config.ResponseImageCheckService
|
||||
for _, obj := range config.ConsumerResponseCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if responseCheckService, ok := obj["responseImageCheckService"]; ok {
|
||||
result, _ = responseCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string {
|
||||
result := config.RiskLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
|
||||
result, _ = riskLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string {
|
||||
result := config.ContentModerationLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
|
||||
result, _ = contentModerationLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string {
|
||||
result := config.PromptAttackLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
|
||||
result, _ = promptAttackLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string {
|
||||
result := config.SensitiveDataLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
|
||||
result, _ = sensitiveDataLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string {
|
||||
result := config.MaliciousUrlLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
|
||||
result, _ = maliciousUrlLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string {
|
||||
result := config.ModelHallucinationLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
|
||||
result, _ = modelHallucinationLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func LevelToInt(riskLevel string) int {
|
||||
// First check against our defined constants
|
||||
switch strings.ToLower(riskLevel) {
|
||||
case MaxRisk, S4Sensitive:
|
||||
return 4
|
||||
case HighRisk, S3Sensitive:
|
||||
return 3
|
||||
case MediumRisk, S2Sensitive:
|
||||
return 2
|
||||
case LowRisk, S1Sensitive:
|
||||
return 1
|
||||
case NoRisk, NoSensitive:
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
|
||||
if action == MultiModalGuard || action == MultiModalGuardForBase64 {
|
||||
// Check top-level risk levels for MultiModalGuard
|
||||
if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
// Also check AttackLevel for prompt attack detection
|
||||
if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check detailed results for backward compatibility
|
||||
for _, detail := range data.Detail {
|
||||
switch detail.Type {
|
||||
case ContentModerationType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case PromptAttackType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case SensitiveDataType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case MaliciousUrlDataType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case ModelHallucinationDataType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,8 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
@@ -20,5 +20,6 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -4,8 +4,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
@@ -24,6 +28,8 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
ALGORITHM = "ACS3-HMAC-SHA256"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
httpMethod string
|
||||
canonicalUri string
|
||||
host string
|
||||
xAcsAction string
|
||||
xAcsVersion string
|
||||
headers map[string]string
|
||||
body []byte
|
||||
queryParam map[string]interface{}
|
||||
}
|
||||
|
||||
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
|
||||
req := &Request{
|
||||
httpMethod: httpMethod,
|
||||
canonicalUri: canonicalUri,
|
||||
host: host,
|
||||
xAcsAction: xAcsAction,
|
||||
xAcsVersion: xAcsVersion,
|
||||
headers: make(map[string]string),
|
||||
queryParam: make(map[string]interface{}),
|
||||
}
|
||||
req.headers["host"] = host
|
||||
req.headers["x-acs-action"] = xAcsAction
|
||||
req.headers["x-acs-version"] = xAcsVersion
|
||||
req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339)
|
||||
req.headers["x-acs-signature-nonce"] = uuid.New().String()
|
||||
return req
|
||||
}
|
||||
|
||||
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
|
||||
newQueryParams := make(map[string]interface{})
|
||||
processObject(newQueryParams, "", req.queryParam)
|
||||
req.queryParam = newQueryParams
|
||||
canonicalQueryString := ""
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&"
|
||||
}
|
||||
canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&")
|
||||
|
||||
var bodyContent []byte
|
||||
if req.body == nil {
|
||||
bodyContent = []byte("")
|
||||
} else {
|
||||
bodyContent = req.body
|
||||
}
|
||||
hashedRequestPayload := sha256Hex(bodyContent)
|
||||
req.headers["x-acs-content-sha256"] = hashedRequestPayload
|
||||
|
||||
if SecurityToken != "" {
|
||||
req.headers["x-acs-security-token"] = SecurityToken
|
||||
}
|
||||
|
||||
canonicalHeaders := ""
|
||||
signedHeaders := ""
|
||||
HeadersKeys := maps.Keys(req.headers)
|
||||
sort.Strings(HeadersKeys)
|
||||
for _, k := range HeadersKeys {
|
||||
lowerKey := strings.ToLower(k)
|
||||
if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
|
||||
canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
|
||||
signedHeaders += lowerKey + ";"
|
||||
}
|
||||
}
|
||||
signedHeaders = strings.TrimSuffix(signedHeaders, ";")
|
||||
|
||||
canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
|
||||
|
||||
hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest))
|
||||
stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest
|
||||
|
||||
byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
panic(err)
|
||||
}
|
||||
signature := strings.ToLower(hex.EncodeToString(byteData))
|
||||
|
||||
authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
|
||||
req.headers["Authorization"] = authorization
|
||||
}
|
||||
|
||||
func hmac256(key []byte, toSignString string) ([]byte, error) {
|
||||
h := hmac.New(sha256.New, key)
|
||||
_, err := h.Write([]byte(toSignString))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
func sha256Hex(byteArray []byte) string {
|
||||
hash := sha256.New()
|
||||
_, _ = hash.Write(byteArray)
|
||||
hexString := hex.EncodeToString(hash.Sum(nil))
|
||||
return hexString
|
||||
}
|
||||
|
||||
func percentCode(str string) string {
|
||||
str = strings.ReplaceAll(str, "+", "%20")
|
||||
str = strings.ReplaceAll(str, "*", "%2A")
|
||||
str = strings.ReplaceAll(str, "%7E", "~")
|
||||
return str
|
||||
}
|
||||
|
||||
func formDataToString(formData map[string]interface{}) *string {
|
||||
tmp := make(map[string]interface{})
|
||||
processObject(tmp, "", formData)
|
||||
res := ""
|
||||
urlEncoder := url.Values{}
|
||||
for key, value := range tmp {
|
||||
v := fmt.Sprintf("%v", value)
|
||||
urlEncoder.Add(key, v)
|
||||
}
|
||||
res = urlEncoder.Encode()
|
||||
return &res
|
||||
}
|
||||
|
||||
// processObject 递归处理对象,将复杂对象(如Map和List)展开为平面的键值对
|
||||
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
for i, item := range v {
|
||||
processObject(mapResult, fmt.Sprintf("%s.%d", key, i+1), item)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for subKey, subValue := range v {
|
||||
processObject(mapResult, fmt.Sprintf("%s.%s", key, subKey), subValue)
|
||||
}
|
||||
default:
|
||||
if strings.HasPrefix(key, ".") {
|
||||
key = key[1:]
|
||||
}
|
||||
if b, ok := v.([]byte); ok {
|
||||
mapResult[key] = string(b)
|
||||
} else {
|
||||
mapResult[key] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
|
||||
httpMethod := "POST"
|
||||
canonicalUri := "/"
|
||||
xAcsVersion := "2022-03-02"
|
||||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||
|
||||
req.queryParam["Service"] = checkService
|
||||
|
||||
body := make(map[string]interface{})
|
||||
serviceParameters := make(map[string]interface{})
|
||||
serviceParameters["content"] = text
|
||||
serviceParameters["sessionId"] = sessionID
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
str := formDataToString(body)
|
||||
req.body = []byte(*str)
|
||||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||
|
||||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||
|
||||
q := url.Values{}
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
for k, v := range req.headers {
|
||||
if k != "host" {
|
||||
headers = append(headers, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return "?" + q.Encode(), headers, req.body
|
||||
}
|
||||
|
||||
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
|
||||
httpMethod := "POST"
|
||||
canonicalUri := "/"
|
||||
xAcsVersion := "2022-03-02"
|
||||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||
|
||||
req.queryParam["Service"] = checkService
|
||||
|
||||
body := make(map[string]interface{})
|
||||
serviceParameters := make(map[string]interface{})
|
||||
if imgUrl != "" {
|
||||
serviceParameters["imageUrls"] = []string{imgUrl}
|
||||
}
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
if imgBase64 != "" {
|
||||
body["ImageBase64Str"] = imgBase64
|
||||
}
|
||||
str := formDataToString(body)
|
||||
req.body = []byte(*str)
|
||||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||
|
||||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||
|
||||
q := url.Values{}
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
for k, v := range req.headers {
|
||||
// host will be added by envoy automatically
|
||||
if k != "host" {
|
||||
headers = append(headers, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return "?" + q.Encode(), headers, req.body
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
ctx.SetContext("end_of_stream_received", false)
|
||||
ctx.SetContext("during_call", false)
|
||||
ctx.SetContext("risk_detected", false)
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
ctx.SetContext("sessionID", sessionID)
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.NeedPauseStreamingResponse()
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.BufferResponseBody()
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
}
|
||||
|
||||
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
var sessionID string
|
||||
if ctx.GetContext("sessionID") == nil {
|
||||
sessionID, _ = utils.GenerateHexID(20)
|
||||
ctx.SetContext("sessionID", sessionID)
|
||||
} else {
|
||||
sessionID, _ = ctx.GetContext("sessionID").(string)
|
||||
}
|
||||
var bufferQueue [][]byte
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
ctx.SetContext("during_call", false)
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
ctx.SetContext("during_call", false)
|
||||
return
|
||||
}
|
||||
if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = "\n" + response.Data.Advice[0].Answer
|
||||
} else if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
|
||||
return
|
||||
}
|
||||
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
|
||||
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
|
||||
bufferQueue = [][]byte{}
|
||||
if !endStream {
|
||||
ctx.SetContext("during_call", false)
|
||||
singleCall()
|
||||
}
|
||||
}
|
||||
singleCall = func() {
|
||||
if ctx.GetContext("during_call").(bool) {
|
||||
return
|
||||
}
|
||||
if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
|
||||
var buffer string
|
||||
for ctx.BufferQueueSize() > 0 {
|
||||
front := ctx.PopBuffer()
|
||||
bufferQueue = append(bufferQueue, front)
|
||||
msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
|
||||
buffer += msg
|
||||
if len([]rune(buffer)) >= config.BufferLimit {
|
||||
break
|
||||
}
|
||||
}
|
||||
// if streaming body has reasoning_content, buffer maybe empty
|
||||
log.Debugf("current content piece: %s", buffer)
|
||||
if len(buffer) == 0 {
|
||||
return
|
||||
}
|
||||
ctx.SetContext("during_call", true)
|
||||
log.Debugf("current content piece: %s", buffer)
|
||||
checkService := config.GetResponseCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ctx.GetContext("risk_detected").(bool) {
|
||||
unifiedChunk := wrapper.UnifySSEChunk(data)
|
||||
hasTrailingSeparator := bytes.HasSuffix(unifiedChunk, []byte("\n\n"))
|
||||
trimmedChunk := bytes.TrimSpace(unifiedChunk)
|
||||
chunks := bytes.Split(trimmedChunk, []byte("\n\n"))
|
||||
// Filter out empty chunks
|
||||
nonEmptyChunks := make([][]byte, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
if len(chunk) > 0 {
|
||||
nonEmptyChunks = append(nonEmptyChunks, chunk)
|
||||
}
|
||||
}
|
||||
// Restore separators
|
||||
for i := range len(nonEmptyChunks) - 1 {
|
||||
nonEmptyChunks[i] = append(nonEmptyChunks[i], []byte("\n\n")...)
|
||||
}
|
||||
if hasTrailingSeparator && len(nonEmptyChunks) > 0 {
|
||||
nonEmptyChunks[len(nonEmptyChunks)-1] = append(nonEmptyChunks[len(nonEmptyChunks)-1], []byte("\n\n")...)
|
||||
}
|
||||
for _, chunk := range nonEmptyChunks {
|
||||
ctx.PushBuffer(chunk)
|
||||
}
|
||||
// for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
|
||||
// ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
|
||||
// }
|
||||
ctx.SetContext("end_of_stream_received", endOfStream)
|
||||
if !ctx.GetContext("during_call").(bool) {
|
||||
singleCall()
|
||||
}
|
||||
} else if endOfStream {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
||||
var content string
|
||||
if isStreamingResponse {
|
||||
content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
|
||||
} else {
|
||||
content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
|
||||
}
|
||||
log.Debugf("Raw response content is: %s", content)
|
||||
if len(content) == 0 {
|
||||
log.Info("response content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if isStreamingResponse {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
config.IncrementCounter("ai_sec_response_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
checkService := config.GetResponseCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package multi_modal_guard
|
||||
|
||||
import (
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||
}
|
||||
|
||||
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseHeader(ctx, config)
|
||||
case cfg.ApiImageGeneration:
|
||||
switch config.ProviderType {
|
||||
case cfg.ProviderOpenAI, cfg.ProviderQwen:
|
||||
return image.HandleImageGenerationResponseHeader(ctx, config)
|
||||
default:
|
||||
log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
default:
|
||||
log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||
default:
|
||||
log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
|
||||
case cfg.ApiImageGeneration:
|
||||
switch config.ProviderType {
|
||||
case cfg.ProviderOpenAI:
|
||||
return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body)
|
||||
case cfg.ProviderQwen:
|
||||
return image.HandleQwenImageGenerationResponseBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
default:
|
||||
log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
ctx.SetContext("risk_detected", false)
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.BufferResponseBody()
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ImageItemForOpenAI struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
}
|
||||
|
||||
func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
result := []ImageItemForOpenAI{}
|
||||
for _, part := range gjson.GetBytes(body, "data").Array() {
|
||||
if url := part.Get("url").String(); url != "" {
|
||||
result = append(result, ImageItemForOpenAI{
|
||||
Content: url,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
if b64 := part.Get("b64_json").String(); b64 != "" {
|
||||
result = append(result, ImageItemForOpenAI{
|
||||
Content: b64,
|
||||
Type: "BASE64",
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
imgResults := getOpenAIImageResults(body)
|
||||
if len(imgResults) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
imageIndex := 0
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(imgResults) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(imgResults) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(imgResults) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
img := imgResults[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func getQwenImageUrls(body []byte) []string {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
result := []string{}
|
||||
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
|
||||
// 虚拟模特/图像背景生成/人物写真FaceChain/文生图StableDiffusion/文生图FLUX/文字纹理生成API
|
||||
for _, part := range gjson.GetBytes(body, "output.results").Array() {
|
||||
if url := part.Get("url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
// 图像编辑
|
||||
for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() {
|
||||
if url := part.Get("image").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
// 图像翻译/AI试衣OutfitAnyone
|
||||
if url := gjson.GetBytes(body, "output.image_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 图像画面扩展/(part of)人物实例分割/图像擦除补全
|
||||
if url := gjson.GetBytes(body, "output.output_image_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 鞋靴模特
|
||||
if url := gjson.GetBytes(body, "output.result_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 创意海报生成
|
||||
for _, part := range gjson.GetBytes(body, "output.render_urls").Array() {
|
||||
if url := part.String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() {
|
||||
if url := part.String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
// 人物实例分割
|
||||
if url := gjson.GetBytes(body, "output.output_vis_image_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 文字变形API
|
||||
for _, part := range gjson.GetBytes(body, "output.results").Array() {
|
||||
if url := part.Get("png_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
if url := part.Get("svg_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
imgUrls := getQwenImageUrls(body)
|
||||
if len(imgUrls) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
imageIndex := 0
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(imgUrls) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(imgUrls) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(imgUrls) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
imgUrl := imgUrls[imageIndex]
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ImageItemForOpenAI struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
}
|
||||
|
||||
func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI) {
|
||||
images = []ImageItemForOpenAI{}
|
||||
if json.IsArray() {
|
||||
for _, item := range json.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
text += item.Get("text").String()
|
||||
case "image_url":
|
||||
imgContent := item.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
images = append(images, ImageItemForOpenAI{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
images = append(images, ImageItemForOpenAI{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
text = json.String()
|
||||
}
|
||||
return text, images
|
||||
}
|
||||
|
||||
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||
content, images := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath))
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 && len(images) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(images) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCallForImage = func() {
|
||||
img := images[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
singleCall()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package text_moderation_plus
|
||||
|
||||
import (
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||
}
|
||||
|
||||
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseHeader(ctx, config)
|
||||
default:
|
||||
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||
default:
|
||||
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
startTime := time.Now().UnixMilli()
|
||||
content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at request phase")
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.TextModerationPlus, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,8 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, "test-ak", securityConfig.ak)
|
||||
require.Equal(t, "test-sk", securityConfig.sk)
|
||||
require.Equal(t, true, securityConfig.checkRequest)
|
||||
require.Equal(t, true, securityConfig.checkResponse)
|
||||
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
|
||||
require.Equal(t, uint32(2000), securityConfig.timeout)
|
||||
require.Equal(t, 1000, securityConfig.bufferLimit)
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, "test-ak", securityConfig.AK)
|
||||
require.Equal(t, "test-sk", securityConfig.SK)
|
||||
require.Equal(t, true, securityConfig.CheckRequest)
|
||||
require.Equal(t, true, securityConfig.CheckResponse)
|
||||
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||||
require.Equal(t, uint32(2000), securityConfig.Timeout)
|
||||
require.Equal(t, 1000, securityConfig.BufferLimit)
|
||||
})
|
||||
|
||||
// 测试仅检查请求的配置
|
||||
@@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, true, securityConfig.checkRequest)
|
||||
require.Equal(t, false, securityConfig.checkResponse)
|
||||
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, true, securityConfig.CheckRequest)
|
||||
require.Equal(t, false, securityConfig.CheckResponse)
|
||||
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||||
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||||
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||||
})
|
||||
|
||||
// 测试缺少必需字段的配置
|
||||
@@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*AISecurityConfig)
|
||||
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
|
||||
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
|
||||
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
|
||||
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
|
||||
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
|
||||
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
|
||||
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
|
||||
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
|
||||
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
|
||||
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
|
||||
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -385,62 +387,27 @@ func TestOnHttpResponseHeaders(t *testing.T) {
|
||||
func TestRiskLevelFunctions(t *testing.T) {
|
||||
// 测试风险等级转换函数
|
||||
t.Run("risk level conversion", func(t *testing.T) {
|
||||
require.Equal(t, 4, levelToInt(MaxRisk))
|
||||
require.Equal(t, 3, levelToInt(HighRisk))
|
||||
require.Equal(t, 2, levelToInt(MediumRisk))
|
||||
require.Equal(t, 1, levelToInt(LowRisk))
|
||||
require.Equal(t, 0, levelToInt(NoRisk))
|
||||
require.Equal(t, -1, levelToInt("invalid"))
|
||||
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
|
||||
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
|
||||
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
|
||||
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
|
||||
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
|
||||
require.Equal(t, -1, cfg.LevelToInt("invalid"))
|
||||
})
|
||||
|
||||
// 测试风险等级比较
|
||||
t.Run("risk level comparison", func(t *testing.T) {
|
||||
require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk))
|
||||
require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk))
|
||||
require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk))
|
||||
require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk))
|
||||
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
|
||||
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
|
||||
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
|
||||
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUtilityFunctions(t *testing.T) {
|
||||
// 测试URL编码函数
|
||||
t.Run("url encoding", func(t *testing.T) {
|
||||
original := "test+string:with=special&chars@$"
|
||||
encoded := urlEncoding(original)
|
||||
require.NotEqual(t, original, encoded)
|
||||
require.Contains(t, encoded, "%2B") // + 应该被编码
|
||||
require.Contains(t, encoded, "%3A") // : 应该被编码
|
||||
require.Contains(t, encoded, "%3D") // = 应该被编码
|
||||
require.Contains(t, encoded, "%26") // & 应该被编码
|
||||
})
|
||||
|
||||
// 测试HMAC-SHA1签名函数
|
||||
t.Run("hmac sha1", func(t *testing.T) {
|
||||
message := "test message"
|
||||
secret := "test secret"
|
||||
signature := hmacSha1(message, secret)
|
||||
require.NotEmpty(t, signature)
|
||||
require.NotEqual(t, message, signature)
|
||||
})
|
||||
|
||||
// 测试签名生成函数
|
||||
t.Run("signature generation", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
params := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
secret := "test-secret"
|
||||
signature := getSign(params, secret)
|
||||
require.NotEmpty(t, signature)
|
||||
})
|
||||
|
||||
// 测试十六进制ID生成函数
|
||||
t.Run("hex id generation", func(t *testing.T) {
|
||||
id, err := generateHexID(16)
|
||||
id, err := utils.GenerateHexID(16)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, id, 16)
|
||||
require.Regexp(t, "^[0-9a-f]+$", id)
|
||||
@@ -448,7 +415,7 @@ func TestUtilityFunctions(t *testing.T) {
|
||||
|
||||
// 测试随机ID生成函数
|
||||
t.Run("random id generation", func(t *testing.T) {
|
||||
id := generateRandomID()
|
||||
id := utils.GenerateRandomChatID()
|
||||
require.NotEmpty(t, id)
|
||||
require.Contains(t, id, "chatcmpl-")
|
||||
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
|
||||
|
||||
43
plugins/wasm-go/extensions/ai-security-guard/utils/utils.go
Normal file
43
plugins/wasm-go/extensions/ai-security-guard/utils/utils.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
mrand "math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func GenerateHexID(length int) (string, error) {
|
||||
bytes := make([]byte, length/2)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateRandomChatID() string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, 29)
|
||||
for i := range b {
|
||||
b[i] = charset[mrand.Intn(len(charset))]
|
||||
}
|
||||
return "chatcmpl-" + string(b)
|
||||
}
|
||||
|
||||
func ExtractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
|
||||
strChunks := []string{}
|
||||
for _, chunk := range chunks {
|
||||
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||
strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String())
|
||||
}
|
||||
return strings.Join(strChunks, "")
|
||||
}
|
||||
|
||||
func GetConsumer(ctx wrapper.HttpContext) string {
|
||||
return ctx.GetStringContext("consumer", "")
|
||||
}
|
||||
Reference in New Issue
Block a user