mirror of
https://github.com/alibaba/higress.git
synced 2026-06-08 20:27:31 +08:00
[feat] ai-security-guard support checking prompt and image in request body (#3206)
This commit is contained in:
@@ -219,8 +219,8 @@ func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkServ
|
|||||||
if imgUrl != "" {
|
if imgUrl != "" {
|
||||||
serviceParameters["imageUrls"] = []string{imgUrl}
|
serviceParameters["imageUrls"] = []string{imgUrl}
|
||||||
}
|
}
|
||||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
|
||||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||||
|
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||||
body["ServiceParameters"] = serviceParametersJSON
|
body["ServiceParameters"] = serviceParametersJSON
|
||||||
if imgBase64 != "" {
|
if imgBase64 != "" {
|
||||||
body["ImageBase64Str"] = imgBase64
|
body["ImageBase64Str"] = imgBase64
|
||||||
|
|||||||
@@ -15,7 +15,23 @@ func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||||
|
case cfg.ApiImageGeneration:
|
||||||
|
switch config.ProviderType {
|
||||||
|
case cfg.ProviderOpenAI:
|
||||||
|
return image.HandleOpenAIImageGenerationRequestBody(ctx, config, body)
|
||||||
|
case cfg.ProviderQwen:
|
||||||
|
return image.HandleQwenImageGenerationRequestBody(ctx, config, body)
|
||||||
|
default:
|
||||||
|
log.Errorf("[on request body] image generation api don't support provider: %s", config.ProviderType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Errorf("[on request body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
|||||||
@@ -9,6 +9,11 @@ import (
|
|||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ImageItem struct {
|
||||||
|
Content string
|
||||||
|
Type string // URL or BASE64
|
||||||
|
}
|
||||||
|
|
||||||
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||||
ctx.SetContext("risk_detected", false)
|
ctx.SetContext("risk_detected", false)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
@@ -14,23 +15,23 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ImageItemForOpenAI struct {
|
func parseOpenAIRequest(body []byte) (text string, images []ImageItem) {
|
||||||
Content string
|
text = gjson.GetBytes(body, "prompt").String()
|
||||||
Type string // URL or BASE64
|
return text, images
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
|
func parseOpenAIResponse(body []byte) []ImageItem {
|
||||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||||
result := []ImageItemForOpenAI{}
|
result := []ImageItem{}
|
||||||
for _, part := range gjson.GetBytes(body, "data").Array() {
|
for _, part := range gjson.GetBytes(body, "data").Array() {
|
||||||
if url := part.Get("url").String(); url != "" {
|
if url := part.Get("url").String(); url != "" {
|
||||||
result = append(result, ImageItemForOpenAI{
|
result = append(result, ImageItem{
|
||||||
Content: url,
|
Content: url,
|
||||||
Type: "URL",
|
Type: "URL",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if b64 := part.Get("b64_json").String(); b64 != "" {
|
if b64 := part.Get("b64_json").String(); b64 != "" {
|
||||||
result = append(result, ImageItemForOpenAI{
|
result = append(result, ImageItem{
|
||||||
Content: b64,
|
Content: b64,
|
||||||
Type: "BASE64",
|
Type: "BASE64",
|
||||||
})
|
})
|
||||||
@@ -39,12 +40,171 @@ func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
checkService := config.GetRequestCheckService(consumer)
|
||||||
|
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
content, images := parseOpenAIRequest(body)
|
||||||
|
log.Debugf("Raw request content is: %s", content)
|
||||||
|
if len(content) == 0 && len(images) == 0 {
|
||||||
|
log.Info("request content is empty. skip")
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
contentIndex := 0
|
||||||
|
imageIndex := 0
|
||||||
|
sessionID, _ := utils.GenerateHexID(20)
|
||||||
|
var singleCall func()
|
||||||
|
var singleCallForImage func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if contentIndex >= len(content) {
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
if len(images) > 0 && config.CheckRequestImage {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
var nextContentIndex int
|
||||||
|
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||||
|
nextContentIndex = len(content)
|
||||||
|
} else {
|
||||||
|
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||||
|
}
|
||||||
|
contentPiece := content[contentIndex:nextContentIndex]
|
||||||
|
contentIndex = nextContentIndex
|
||||||
|
log.Debugf("current content piece: %s", contentPiece)
|
||||||
|
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
imageIndex += 1
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
if imageIndex < len(images) {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
if imageIndex < len(images) {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if imageIndex >= len(images) {
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
} else {
|
||||||
|
singleCallForImage()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCallForImage = func() {
|
||||||
|
img := images[imageIndex]
|
||||||
|
imgUrl := ""
|
||||||
|
imgBase64 := ""
|
||||||
|
if img.Type == "BASE64" {
|
||||||
|
imgBase64 = img.Content
|
||||||
|
} else {
|
||||||
|
imgUrl = img.Content
|
||||||
|
}
|
||||||
|
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||||
|
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(content) > 0 {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
singleCallForImage()
|
||||||
|
}
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
|
|
||||||
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
consumer, _ := ctx.GetContext("consumer").(string)
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
log.Debugf("checking response body...")
|
log.Debugf("checking response body...")
|
||||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||||
startTime := time.Now().UnixMilli()
|
startTime := time.Now().UnixMilli()
|
||||||
imgResults := getOpenAIImageResults(body)
|
imgResults := parseOpenAIResponse(body)
|
||||||
if len(imgResults) == 0 {
|
if len(imgResults) == 0 {
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package image
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
@@ -14,7 +16,133 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getQwenImageUrls(body []byte) []string {
|
func parseImage(body []byte, jsonPath string) *ImageItem {
|
||||||
|
if gjson.GetBytes(body, jsonPath).Exists() {
|
||||||
|
imgContent := gjson.GetBytes(body, jsonPath).String()
|
||||||
|
if strings.HasPrefix(imgContent, "data:image") {
|
||||||
|
return &ImageItem{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "BASE64",
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return &ImageItem{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "URL",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseImageArray(body []byte, jsonPath string) []ImageItem {
|
||||||
|
result := []ImageItem{}
|
||||||
|
if gjson.GetBytes(body, jsonPath).Exists() {
|
||||||
|
for _, item := range gjson.GetBytes(body, jsonPath).Array() {
|
||||||
|
imgContent := item.String()
|
||||||
|
if strings.HasPrefix(imgContent, "data:image") {
|
||||||
|
result = append(result, ImageItem{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "BASE64",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
result = append(result, ImageItem{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "URL",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseQwenRequest(body []byte) (text string, images []ImageItem) {
|
||||||
|
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||||
|
images = []ImageItem{}
|
||||||
|
// 文生图/文生图v1/文生图v2
|
||||||
|
if gjson.GetBytes(body, "input.prompt").Exists() {
|
||||||
|
text += gjson.GetBytes(body, "input.prompt").String()
|
||||||
|
}
|
||||||
|
// 图像背景生成
|
||||||
|
if gjson.GetBytes(body, "input.ref_prompt").Exists() {
|
||||||
|
text += gjson.GetBytes(body, "input.ref_prompt").String()
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Exists() {
|
||||||
|
for _, item := range gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Array() {
|
||||||
|
text += item.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Exists() {
|
||||||
|
for _, item := range gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Array() {
|
||||||
|
text += item.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 创意文字
|
||||||
|
if gjson.GetBytes(body, "input.text").Exists() {
|
||||||
|
text += gjson.GetBytes(body, "input.text").String()
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "input.negative_prompt").Exists() {
|
||||||
|
text += gjson.GetBytes(body, "input.negative_prompt").String()
|
||||||
|
}
|
||||||
|
// 图像编辑
|
||||||
|
if gjson.GetBytes(body, "input.messages.0.content").Exists() {
|
||||||
|
for _, item := range gjson.GetBytes(body, "input.messages.0.content").Array() {
|
||||||
|
if item.Get("text").Exists() {
|
||||||
|
text += item.Get("text").String()
|
||||||
|
} else if item.Get("image").Exists() {
|
||||||
|
imgContent := item.Get("image").String()
|
||||||
|
if strings.HasPrefix(imgContent, "data:image") {
|
||||||
|
images = append(images, ImageItem{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "BASE64",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
images = append(images, ImageItem{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "URL",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// image json path
|
||||||
|
imageJsonPath := []string{
|
||||||
|
"input.image_url", // 图像翻译/人像风格重绘/图像画面扩展/人物实例分割/图像擦除补全
|
||||||
|
"input.base_image_url", // 通用图像编辑2.1/图像局部重绘/虚拟模特
|
||||||
|
"input.mask_image_url", // 通用图像编辑2.1/图像局部重绘/虚拟模特
|
||||||
|
"input.sketch_image_url", // 涂鸦作画
|
||||||
|
"input.template_image_url", // 鞋靴模特
|
||||||
|
"input.shoe_image_url", // 鞋靴模特
|
||||||
|
"input.base_image_url", // 图像背景生成
|
||||||
|
"input.ref_image_url", // 图像背景生成
|
||||||
|
"input.mask_url", // 图像擦除补全
|
||||||
|
"input.foreground_url", // 图像擦除补全
|
||||||
|
"input.person_image_url", // AI试衣
|
||||||
|
"input.top_garment_url", // AI试衣
|
||||||
|
"input.bottom_garment_url", // AI试衣
|
||||||
|
"input.coarse_image_url", // AI试衣
|
||||||
|
"input.template_url", // 人物写真生成
|
||||||
|
}
|
||||||
|
for _, jsonPath := range imageJsonPath {
|
||||||
|
tmpImage := parseImage(body, jsonPath)
|
||||||
|
if tmpImage != nil {
|
||||||
|
images = append(images, *tmpImage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// image array json path
|
||||||
|
imageArrayJsonPath := []string{
|
||||||
|
"input.images", // 通用图像编辑2.5/人物图像检测
|
||||||
|
"input.reference_edge.foreground_edge", // 图像背景生成
|
||||||
|
"input.reference_edge.background_edge", // 图像背景生成
|
||||||
|
"input.user_urls", // 人物写真生成
|
||||||
|
}
|
||||||
|
for _, jsonPath := range imageArrayJsonPath {
|
||||||
|
tmpImageArray := parseImageArray(body, jsonPath)
|
||||||
|
images = append(images, tmpImageArray...)
|
||||||
|
}
|
||||||
|
return text, images
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseQwenResponse(body []byte) []string {
|
||||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||||
result := []string{}
|
result := []string{}
|
||||||
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
|
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
|
||||||
@@ -69,12 +197,172 @@ func getQwenImageUrls(body []byte) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
checkService := config.GetRequestCheckService(consumer)
|
||||||
|
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||||
|
content, images := parseQwenRequest(body)
|
||||||
|
log.Debugf("Raw request content is: %s", content)
|
||||||
|
if len(content) == 0 && len(images) == 0 {
|
||||||
|
log.Info("request content is empty. skip")
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
contentIndex := 0
|
||||||
|
imageIndex := 0
|
||||||
|
sessionID, _ := utils.GenerateHexID(20)
|
||||||
|
var singleCall func()
|
||||||
|
var singleCallForImage func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if contentIndex >= len(content) {
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
if len(images) > 0 && config.CheckRequestImage {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
var nextContentIndex int
|
||||||
|
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||||
|
nextContentIndex = len(content)
|
||||||
|
} else {
|
||||||
|
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||||
|
}
|
||||||
|
contentPiece := content[contentIndex:nextContentIndex]
|
||||||
|
contentIndex = nextContentIndex
|
||||||
|
log.Debugf("current content piece: %s", contentPiece)
|
||||||
|
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
imageIndex += 1
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
if imageIndex < len(images) {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
if imageIndex < len(images) {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if imageIndex >= len(images) {
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
} else {
|
||||||
|
singleCallForImage()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCallForImage = func() {
|
||||||
|
img := images[imageIndex]
|
||||||
|
imgUrl := ""
|
||||||
|
imgBase64 := ""
|
||||||
|
if img.Type == "BASE64" {
|
||||||
|
imgBase64 = img.Content
|
||||||
|
} else {
|
||||||
|
imgUrl = img.Content
|
||||||
|
}
|
||||||
|
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||||
|
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(content) > 0 {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
singleCallForImage()
|
||||||
|
}
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
|
|
||||||
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
consumer, _ := ctx.GetContext("consumer").(string)
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
log.Debugf("checking response body...")
|
log.Debugf("checking response body...")
|
||||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||||
startTime := time.Now().UnixMilli()
|
startTime := time.Now().UnixMilli()
|
||||||
imgUrls := getQwenImageUrls(body)
|
imgUrls := parseQwenResponse(body)
|
||||||
if len(imgUrls) == 0 {
|
if len(imgUrls) == 0 {
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
@@ -114,7 +402,14 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
|||||||
@@ -17,13 +17,13 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ImageItemForOpenAI struct {
|
type ImageItem struct {
|
||||||
Content string
|
Content string
|
||||||
Type string // URL or BASE64
|
Type string // URL or BASE64
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI) {
|
func parseContent(json gjson.Result) (text string, images []ImageItem) {
|
||||||
images = []ImageItemForOpenAI{}
|
images = []ImageItem{}
|
||||||
if json.IsArray() {
|
if json.IsArray() {
|
||||||
for _, item := range json.Array() {
|
for _, item := range json.Array() {
|
||||||
switch item.Get("type").String() {
|
switch item.Get("type").String() {
|
||||||
@@ -32,12 +32,12 @@ func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI)
|
|||||||
case "image_url":
|
case "image_url":
|
||||||
imgContent := item.Get("image_url.url").String()
|
imgContent := item.Get("image_url.url").String()
|
||||||
if strings.HasPrefix(imgContent, "data:image") {
|
if strings.HasPrefix(imgContent, "data:image") {
|
||||||
images = append(images, ImageItemForOpenAI{
|
images = append(images, ImageItem{
|
||||||
Content: imgContent,
|
Content: imgContent,
|
||||||
Type: "BASE64",
|
Type: "BASE64",
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
images = append(images, ImageItemForOpenAI{
|
images = append(images, ImageItem{
|
||||||
Content: imgContent,
|
Content: imgContent,
|
||||||
Type: "URL",
|
Type: "URL",
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user