[feat] ai-security-guard refactor & support checking multimoadl input (#3075)

This commit is contained in:
rinfx
2025-12-04 16:33:59 +08:00
committed by jingze
parent 1582fa6ef9
commit 9978db2ac6
15 changed files with 1932 additions and 1014 deletions

View File

@@ -0,0 +1,67 @@
package multi_modal_guard
import (
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
return types.ActionContinue
}
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
return text.HandleTextGenerationRequestBody(ctx, config, body)
}
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationResponseHeader(ctx, config)
case cfg.ApiImageGeneration:
switch config.ProviderType {
case cfg.ProviderOpenAI, cfg.ProviderQwen:
return image.HandleImageGenerationResponseHeader(ctx, config)
default:
log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType)
return types.ActionContinue
}
default:
log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
return types.ActionContinue
}
}
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
default:
log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
return data
}
}
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
switch config.ApiType {
case cfg.ApiTextGeneration:
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
case cfg.ApiImageGeneration:
switch config.ProviderType {
case cfg.ProviderOpenAI:
return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body)
case cfg.ProviderQwen:
return image.HandleQwenImageGenerationResponseBody(ctx, config, body)
default:
log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType)
return types.ActionContinue
}
default:
log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
return types.ActionContinue
}
}

View File

@@ -0,0 +1,22 @@
package image
import (
"strings"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
ctx.SetContext("risk_detected", false)
if strings.Contains(contentType, "text/event-stream") {
ctx.DontReadResponseBody()
return types.ActionContinue
} else {
ctx.BufferResponseBody()
return types.HeaderStopIteration
}
}

View File

@@ -0,0 +1,111 @@
package image
import (
"encoding/json"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type ImageItemForOpenAI struct {
Content string
Type string // URL or BASE64
}
func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
result := []ImageItemForOpenAI{}
for _, part := range gjson.GetBytes(body, "data").Array() {
if url := part.Get("url").String(); url != "" {
result = append(result, ImageItemForOpenAI{
Content: url,
Type: "URL",
})
}
if b64 := part.Get("b64_json").String(); b64 != "" {
result = append(result, ImageItemForOpenAI{
Content: b64,
Type: "BASE64",
})
}
}
return result
}
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking response body...")
checkImageService := config.GetResponseImageCheckService(consumer)
startTime := time.Now().UnixMilli()
imgResults := getOpenAIImageResults(body)
if len(imgResults) == 0 {
return types.ActionContinue
}
imageIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if imageIndex < len(imgResults) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
if imageIndex < len(imgResults) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(imgResults) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
}
return
}
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
img := imgResults[imageIndex]
imgUrl := ""
imgBase64 := ""
if img.Type == "BASE64" {
imgBase64 = img.Content
} else {
imgUrl = img.Content
}
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
}
}
singleCall()
return types.ActionPause
}

View File

@@ -0,0 +1,134 @@
package image
import (
"encoding/json"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
func getQwenImageUrls(body []byte) []string {
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
result := []string{}
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
// 虚拟模特/图像背景生成/人物写真FaceChain/文生图StableDiffusion/文生图FLUX/文字纹理生成API
for _, part := range gjson.GetBytes(body, "output.results").Array() {
if url := part.Get("url").String(); url != "" {
result = append(result, url)
}
}
// 图像编辑
for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() {
if url := part.Get("image").String(); url != "" {
result = append(result, url)
}
}
// 图像翻译/AI试衣OutfitAnyone
if url := gjson.GetBytes(body, "output.image_url").String(); url != "" {
result = append(result, url)
}
// 图像画面扩展/(part of)人物实例分割/图像擦除补全
if url := gjson.GetBytes(body, "output.output_image_url").String(); url != "" {
result = append(result, url)
}
// 鞋靴模特
if url := gjson.GetBytes(body, "output.result_url").String(); url != "" {
result = append(result, url)
}
// 创意海报生成
for _, part := range gjson.GetBytes(body, "output.render_urls").Array() {
if url := part.String(); url != "" {
result = append(result, url)
}
}
for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() {
if url := part.String(); url != "" {
result = append(result, url)
}
}
// 人物实例分割
if url := gjson.GetBytes(body, "output.output_vis_image_url").String(); url != "" {
result = append(result, url)
}
// 文字变形API
for _, part := range gjson.GetBytes(body, "output.results").Array() {
if url := part.Get("png_url").String(); url != "" {
result = append(result, url)
}
if url := part.Get("svg_url").String(); url != "" {
result = append(result, url)
}
}
return result
}
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
log.Debugf("checking response body...")
checkImageService := config.GetResponseImageCheckService(consumer)
startTime := time.Now().UnixMilli()
imgUrls := getQwenImageUrls(body)
if len(imgUrls) == 0 {
return types.ActionContinue
}
imageIndex := 0
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if imageIndex < len(imgUrls) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
if imageIndex < len(imgUrls) {
singleCall()
} else {
proxywasm.ResumeHttpResponse()
}
return
}
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(imgUrls) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
}
return
}
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
imgUrl := imgUrls[imageIndex]
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
}
}
singleCall()
return types.ActionPause
}

View File

@@ -0,0 +1,231 @@
package text
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type ImageItemForOpenAI struct {
Content string
Type string // URL or BASE64
}
func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI) {
images = []ImageItemForOpenAI{}
if json.IsArray() {
for _, item := range json.Array() {
switch item.Get("type").String() {
case "text":
text += item.Get("text").String()
case "image_url":
imgContent := item.Get("image_url.url").String()
if strings.HasPrefix(imgContent, "data:image") {
images = append(images, ImageItemForOpenAI{
Content: imgContent,
Type: "BASE64",
})
} else {
images = append(images, ImageItemForOpenAI{
Content: imgContent,
Type: "URL",
})
}
}
}
} else {
text = json.String()
}
return text, images
}
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
consumer, _ := ctx.GetContext("consumer").(string)
checkService := config.GetRequestCheckService(consumer)
checkImageService := config.GetRequestImageCheckService(consumer)
startTime := time.Now().UnixMilli()
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
content, images := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath))
log.Debugf("Raw request content is: %s", content)
if len(content) == 0 && len(images) == 0 {
log.Info("request content is empty. skip")
return types.ActionContinue
}
contentIndex := 0
imageIndex := 0
sessionID, _ := utils.GenerateHexID(20)
var singleCall func()
var singleCallForImage func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpRequest()
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
proxywasm.ResumeHttpRequest()
return
}
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
if len(images) > 0 && config.CheckRequestImage {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
} else {
singleCall()
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
var nextContentIndex int
if contentIndex+cfg.LengthLimit >= len(content) {
nextContentIndex = len(content)
} else {
nextContentIndex = contentIndex + cfg.LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
err := config.Client.Post(path, headers, body, callback, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
}
}
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
imageIndex += 1
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
if imageIndex < len(images) {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
return
}
var response cfg.Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("%+v", err)
if imageIndex < len(images) {
singleCallForImage()
} else {
proxywasm.ResumeHttpRequest()
}
return
}
endTime := time.Now().UnixMilli()
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
if imageIndex >= len(images) {
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpRequest()
} else {
singleCallForImage()
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.IncrementCounter("ai_sec_request_deny", 1)
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCallForImage = func() {
img := images[imageIndex]
imgUrl := ""
imgBase64 := ""
if img.Type == "BASE64" {
imgBase64 = img.Content
} else {
imgUrl = img.Content
}
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
}
}
if len(content) > 0 {
singleCall()
} else {
singleCallForImage()
}
return types.ActionPause
}