mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
[feat] ai-security-guard support checking prompt and image in request body (#3206)
This commit is contained in:
@@ -219,8 +219,8 @@ func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkServ
|
||||
if imgUrl != "" {
|
||||
serviceParameters["imageUrls"] = []string{imgUrl}
|
||||
}
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
if imgBase64 != "" {
|
||||
body["ImageBase64Str"] = imgBase64
|
||||
|
||||
@@ -15,7 +15,23 @@ func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig)
|
||||
}
|
||||
|
||||
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||
case cfg.ApiImageGeneration:
|
||||
switch config.ProviderType {
|
||||
case cfg.ProviderOpenAI:
|
||||
return image.HandleOpenAIImageGenerationRequestBody(ctx, config, body)
|
||||
case cfg.ProviderQwen:
|
||||
return image.HandleQwenImageGenerationRequestBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("[on request body] image generation api don't support provider: %s", config.ProviderType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
default:
|
||||
log.Errorf("[on request body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
|
||||
@@ -9,6 +9,11 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type ImageItem struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
}
|
||||
|
||||
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
ctx.SetContext("risk_detected", false)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
@@ -14,23 +15,23 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ImageItemForOpenAI struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
func parseOpenAIRequest(body []byte) (text string, images []ImageItem) {
|
||||
text = gjson.GetBytes(body, "prompt").String()
|
||||
return text, images
|
||||
}
|
||||
|
||||
func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
|
||||
func parseOpenAIResponse(body []byte) []ImageItem {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
result := []ImageItemForOpenAI{}
|
||||
result := []ImageItem{}
|
||||
for _, part := range gjson.GetBytes(body, "data").Array() {
|
||||
if url := part.Get("url").String(); url != "" {
|
||||
result = append(result, ImageItemForOpenAI{
|
||||
result = append(result, ImageItem{
|
||||
Content: url,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
if b64 := part.Get("b64_json").String(); b64 != "" {
|
||||
result = append(result, ImageItemForOpenAI{
|
||||
result = append(result, ImageItem{
|
||||
Content: b64,
|
||||
Type: "BASE64",
|
||||
})
|
||||
@@ -39,12 +40,171 @@ func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
|
||||
return result
|
||||
}
|
||||
|
||||
func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
content, images := parseOpenAIRequest(body)
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 && len(images) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(images) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCallForImage = func() {
|
||||
img := images[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
singleCall()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
imgResults := getOpenAIImageResults(body)
|
||||
imgResults := parseOpenAIResponse(body)
|
||||
if len(imgResults) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
@@ -3,10 +3,12 @@ package image
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
@@ -14,7 +16,133 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func getQwenImageUrls(body []byte) []string {
|
||||
func parseImage(body []byte, jsonPath string) *ImageItem {
|
||||
if gjson.GetBytes(body, jsonPath).Exists() {
|
||||
imgContent := gjson.GetBytes(body, jsonPath).String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
return &ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
}
|
||||
} else {
|
||||
return &ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseImageArray(body []byte, jsonPath string) []ImageItem {
|
||||
result := []ImageItem{}
|
||||
if gjson.GetBytes(body, jsonPath).Exists() {
|
||||
for _, item := range gjson.GetBytes(body, jsonPath).Array() {
|
||||
imgContent := item.String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
result = append(result, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
result = append(result, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseQwenRequest(body []byte) (text string, images []ImageItem) {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
images = []ImageItem{}
|
||||
// 文生图/文生图v1/文生图v2
|
||||
if gjson.GetBytes(body, "input.prompt").Exists() {
|
||||
text += gjson.GetBytes(body, "input.prompt").String()
|
||||
}
|
||||
// 图像背景生成
|
||||
if gjson.GetBytes(body, "input.ref_prompt").Exists() {
|
||||
text += gjson.GetBytes(body, "input.ref_prompt").String()
|
||||
}
|
||||
if gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Exists() {
|
||||
for _, item := range gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Array() {
|
||||
text += item.String()
|
||||
}
|
||||
}
|
||||
if gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Exists() {
|
||||
for _, item := range gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Array() {
|
||||
text += item.String()
|
||||
}
|
||||
}
|
||||
// 创意文字
|
||||
if gjson.GetBytes(body, "input.text").Exists() {
|
||||
text += gjson.GetBytes(body, "input.text").String()
|
||||
}
|
||||
if gjson.GetBytes(body, "input.negative_prompt").Exists() {
|
||||
text += gjson.GetBytes(body, "input.negative_prompt").String()
|
||||
}
|
||||
// 图像编辑
|
||||
if gjson.GetBytes(body, "input.messages.0.content").Exists() {
|
||||
for _, item := range gjson.GetBytes(body, "input.messages.0.content").Array() {
|
||||
if item.Get("text").Exists() {
|
||||
text += item.Get("text").String()
|
||||
} else if item.Get("image").Exists() {
|
||||
imgContent := item.Get("image").String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// image json path
|
||||
imageJsonPath := []string{
|
||||
"input.image_url", // 图像翻译/人像风格重绘/图像画面扩展/人物实例分割/图像擦除补全
|
||||
"input.base_image_url", // 通用图像编辑2.1/图像局部重绘/虚拟模特
|
||||
"input.mask_image_url", // 通用图像编辑2.1/图像局部重绘/虚拟模特
|
||||
"input.sketch_image_url", // 涂鸦作画
|
||||
"input.template_image_url", // 鞋靴模特
|
||||
"input.shoe_image_url", // 鞋靴模特
|
||||
"input.base_image_url", // 图像背景生成
|
||||
"input.ref_image_url", // 图像背景生成
|
||||
"input.mask_url", // 图像擦除补全
|
||||
"input.foreground_url", // 图像擦除补全
|
||||
"input.person_image_url", // AI试衣
|
||||
"input.top_garment_url", // AI试衣
|
||||
"input.bottom_garment_url", // AI试衣
|
||||
"input.coarse_image_url", // AI试衣
|
||||
"input.template_url", // 人物写真生成
|
||||
}
|
||||
for _, jsonPath := range imageJsonPath {
|
||||
tmpImage := parseImage(body, jsonPath)
|
||||
if tmpImage != nil {
|
||||
images = append(images, *tmpImage)
|
||||
}
|
||||
}
|
||||
// image array json path
|
||||
imageArrayJsonPath := []string{
|
||||
"input.images", // 通用图像编辑2.5/人物图像检测
|
||||
"input.reference_edge.foreground_edge", // 图像背景生成
|
||||
"input.reference_edge.background_edge", // 图像背景生成
|
||||
"input.user_urls", // 人物写真生成
|
||||
}
|
||||
for _, jsonPath := range imageArrayJsonPath {
|
||||
tmpImageArray := parseImageArray(body, jsonPath)
|
||||
images = append(images, tmpImageArray...)
|
||||
}
|
||||
return text, images
|
||||
}
|
||||
|
||||
func parseQwenResponse(body []byte) []string {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
result := []string{}
|
||||
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
|
||||
@@ -69,12 +197,172 @@ func getQwenImageUrls(body []byte) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||
content, images := parseQwenRequest(body)
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 && len(images) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(images) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCallForImage = func() {
|
||||
img := images[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
singleCall()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
imgUrls := getQwenImageUrls(body)
|
||||
imgUrls := parseQwenResponse(body)
|
||||
if len(imgUrls) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -114,7 +402,14 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
|
||||
}
|
||||
return
|
||||
}
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
|
||||
@@ -17,13 +17,13 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ImageItemForOpenAI struct {
|
||||
type ImageItem struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
}
|
||||
|
||||
func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI) {
|
||||
images = []ImageItemForOpenAI{}
|
||||
func parseContent(json gjson.Result) (text string, images []ImageItem) {
|
||||
images = []ImageItem{}
|
||||
if json.IsArray() {
|
||||
for _, item := range json.Array() {
|
||||
switch item.Get("type").String() {
|
||||
@@ -32,12 +32,12 @@ func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI)
|
||||
case "image_url":
|
||||
imgContent := item.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
images = append(images, ImageItemForOpenAI{
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
images = append(images, ImageItemForOpenAI{
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user