mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
[feat] ai-security-guard refactor & support checking multimoadl input (#3075)
This commit is contained in:
585
plugins/wasm-go/extensions/ai-security-guard/config/config.go
Normal file
585
plugins/wasm-go/extensions/ai-security-guard/config/config.go
Normal file
@@ -0,0 +1,585 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxRisk = "max"
|
||||||
|
HighRisk = "high"
|
||||||
|
MediumRisk = "medium"
|
||||||
|
LowRisk = "low"
|
||||||
|
NoRisk = "none"
|
||||||
|
|
||||||
|
S4Sensitive = "s4"
|
||||||
|
S3Sensitive = "s3"
|
||||||
|
S2Sensitive = "s2"
|
||||||
|
S1Sensitive = "s1"
|
||||||
|
NoSensitive = "s0"
|
||||||
|
|
||||||
|
ContentModerationType = "contentModeration"
|
||||||
|
PromptAttackType = "promptAttack"
|
||||||
|
SensitiveDataType = "sensitiveData"
|
||||||
|
MaliciousUrlDataType = "maliciousUrl"
|
||||||
|
ModelHallucinationDataType = "modelHallucination"
|
||||||
|
|
||||||
|
// Default configurations
|
||||||
|
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||||
|
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||||
|
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||||
|
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
||||||
|
|
||||||
|
DefaultDenyCode = 200
|
||||||
|
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
|
||||||
|
DefaultTimeout = 2000
|
||||||
|
|
||||||
|
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||||
|
LengthLimit = 1800
|
||||||
|
|
||||||
|
DefaultRequestCheckService = "llm_query_moderation"
|
||||||
|
DefaultResponseCheckService = "llm_response_moderation"
|
||||||
|
DefaultRequestJsonPath = "messages.@reverse.0.content"
|
||||||
|
DefaultResponseJsonPath = "choices.0.message.content"
|
||||||
|
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
|
||||||
|
|
||||||
|
// Actions
|
||||||
|
MultiModalGuard = "MultiModalGuard"
|
||||||
|
MultiModalGuardForBase64 = "MultiModalGuardForBase64"
|
||||||
|
TextModerationPlus = "TextModerationPlus"
|
||||||
|
|
||||||
|
// Services
|
||||||
|
DefaultMultiModalGuardTextInputCheckService = "query_security_check"
|
||||||
|
DefaultMultiModalGuardTextOutputCheckService = "response_security_check"
|
||||||
|
DefaultMultiModalGuardImageInputCheckService = "img_query_security_check"
|
||||||
|
|
||||||
|
DefaultTextModerationPlusTextInputCheckService = "llm_query_moderation"
|
||||||
|
DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
|
||||||
|
)
|
||||||
|
|
||||||
|
// api types
|
||||||
|
|
||||||
|
const (
|
||||||
|
ApiTextGeneration = "text_generation"
|
||||||
|
ApiImageGeneration = "image_generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
// provider types
|
||||||
|
const (
|
||||||
|
ProviderOpenAI = "openai"
|
||||||
|
ProviderQwen = "qwen"
|
||||||
|
ProviderComfyUI = "comfyui"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
Code int `json:"Code"`
|
||||||
|
Message string `json:"Message"`
|
||||||
|
RequestId string `json:"RequestId"`
|
||||||
|
Data Data `json:"Data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Data struct {
|
||||||
|
RiskLevel string `json:"RiskLevel,omitempty"`
|
||||||
|
AttackLevel string `json:"AttackLevel,omitempty"`
|
||||||
|
Result []Result `json:"Result,omitempty"`
|
||||||
|
Advice []Advice `json:"Advice,omitempty"`
|
||||||
|
Detail []Detail `json:"Detail,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
RiskWords string `json:"RiskWords,omitempty"`
|
||||||
|
Description string `json:"Description,omitempty"`
|
||||||
|
Confidence float64 `json:"Confidence,omitempty"`
|
||||||
|
Label string `json:"Label,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Advice struct {
|
||||||
|
Answer string `json:"Answer,omitempty"`
|
||||||
|
HitLabel string `json:"HitLabel,omitempty"`
|
||||||
|
HitLibName string `json:"HitLibName,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detail struct {
|
||||||
|
Suggestion string `json:"Suggestion,omitempty"`
|
||||||
|
Type string `json:"Type,omitempty"`
|
||||||
|
Level string `json:"Level,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Matcher struct {
|
||||||
|
Exact string
|
||||||
|
Prefix string
|
||||||
|
Re *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Matcher) match(consumer string) bool {
|
||||||
|
if m.Exact != "" {
|
||||||
|
return consumer == m.Exact
|
||||||
|
} else if m.Prefix != "" {
|
||||||
|
return strings.HasPrefix(consumer, m.Prefix)
|
||||||
|
} else if m.Re != nil {
|
||||||
|
return m.Re.MatchString(consumer)
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AISecurityConfig struct {
|
||||||
|
Client wrapper.HttpClient
|
||||||
|
Host string
|
||||||
|
AK string
|
||||||
|
SK string
|
||||||
|
Token string
|
||||||
|
Action string
|
||||||
|
CheckRequest bool
|
||||||
|
CheckRequestImage bool
|
||||||
|
RequestCheckService string
|
||||||
|
RequestImageCheckService string
|
||||||
|
RequestContentJsonPath string
|
||||||
|
CheckResponse bool
|
||||||
|
ResponseCheckService string
|
||||||
|
ResponseImageCheckService string
|
||||||
|
ResponseContentJsonPath string
|
||||||
|
ResponseStreamContentJsonPath string
|
||||||
|
DenyCode int64
|
||||||
|
DenyMessage string
|
||||||
|
ProtocolOriginal bool
|
||||||
|
RiskLevelBar string
|
||||||
|
ContentModerationLevelBar string
|
||||||
|
PromptAttackLevelBar string
|
||||||
|
SensitiveDataLevelBar string
|
||||||
|
MaliciousUrlLevelBar string
|
||||||
|
ModelHallucinationLevelBar string
|
||||||
|
Timeout uint32
|
||||||
|
BufferLimit int
|
||||||
|
Metrics map[string]proxywasm.MetricCounter
|
||||||
|
ConsumerRequestCheckService []map[string]interface{}
|
||||||
|
ConsumerResponseCheckService []map[string]interface{}
|
||||||
|
ConsumerRiskLevel []map[string]interface{}
|
||||||
|
// text_generation, image_generation, etc.
|
||||||
|
ApiType string
|
||||||
|
// openai, qwen, comfyui, etc.
|
||||||
|
ProviderType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) Parse(json gjson.Result) error {
|
||||||
|
serviceName := json.Get("serviceName").String()
|
||||||
|
servicePort := json.Get("servicePort").Int()
|
||||||
|
serviceHost := json.Get("serviceHost").String()
|
||||||
|
config.Host = serviceHost
|
||||||
|
if serviceName == "" || servicePort == 0 || serviceHost == "" {
|
||||||
|
return errors.New("invalid service config")
|
||||||
|
}
|
||||||
|
config.AK = json.Get("accessKey").String()
|
||||||
|
config.SK = json.Get("secretKey").String()
|
||||||
|
if config.AK == "" || config.SK == "" {
|
||||||
|
return errors.New("invalid AK/SK config")
|
||||||
|
}
|
||||||
|
config.Token = json.Get("securityToken").String()
|
||||||
|
// set action
|
||||||
|
if obj := json.Get("action"); obj.Exists() {
|
||||||
|
config.Action = json.Get("action").String()
|
||||||
|
} else {
|
||||||
|
config.Action = TextModerationPlus
|
||||||
|
}
|
||||||
|
// set default values
|
||||||
|
config.SetDefaultValues()
|
||||||
|
// set values
|
||||||
|
if obj := json.Get("riskLevelBar"); obj.Exists() {
|
||||||
|
config.RiskLevelBar = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("requestCheckService"); obj.Exists() {
|
||||||
|
config.RequestCheckService = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("requestImageCheckService"); obj.Exists() {
|
||||||
|
config.RequestImageCheckService = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseCheckService"); obj.Exists() {
|
||||||
|
config.ResponseCheckService = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseImageCheckService"); obj.Exists() {
|
||||||
|
config.ResponseImageCheckService = obj.String()
|
||||||
|
}
|
||||||
|
config.CheckRequest = json.Get("checkRequest").Bool()
|
||||||
|
config.CheckRequestImage = json.Get("checkRequestImage").Bool()
|
||||||
|
config.CheckResponse = json.Get("checkResponse").Bool()
|
||||||
|
config.ProtocolOriginal = json.Get("protocol").String() == "original"
|
||||||
|
config.DenyMessage = json.Get("denyMessage").String()
|
||||||
|
if obj := json.Get("denyCode"); obj.Exists() {
|
||||||
|
config.DenyCode = obj.Int()
|
||||||
|
}
|
||||||
|
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
|
||||||
|
config.RequestContentJsonPath = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
|
||||||
|
config.ResponseContentJsonPath = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
|
||||||
|
config.ResponseStreamContentJsonPath = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
|
||||||
|
config.ContentModerationLevelBar = obj.String()
|
||||||
|
if LevelToInt(config.ContentModerationLevelBar) <= 0 {
|
||||||
|
return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
|
||||||
|
config.PromptAttackLevelBar = obj.String()
|
||||||
|
if LevelToInt(config.PromptAttackLevelBar) <= 0 {
|
||||||
|
return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
|
||||||
|
config.SensitiveDataLevelBar = obj.String()
|
||||||
|
if LevelToInt(config.SensitiveDataLevelBar) <= 0 {
|
||||||
|
return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
|
||||||
|
config.ModelHallucinationLevelBar = obj.String()
|
||||||
|
if LevelToInt(config.ModelHallucinationLevelBar) <= 0 {
|
||||||
|
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
|
||||||
|
config.MaliciousUrlLevelBar = obj.String()
|
||||||
|
if LevelToInt(config.MaliciousUrlLevelBar) <= 0 {
|
||||||
|
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("timeout"); obj.Exists() {
|
||||||
|
config.Timeout = uint32(obj.Int())
|
||||||
|
}
|
||||||
|
if obj := json.Get("bufferLimit"); obj.Exists() {
|
||||||
|
config.BufferLimit = int(obj.Int())
|
||||||
|
}
|
||||||
|
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
|
||||||
|
for _, item := range json.Get("consumerRequestCheckService").Array() {
|
||||||
|
m := make(map[string]interface{})
|
||||||
|
for k, v := range item.Map() {
|
||||||
|
m[k] = v.Value()
|
||||||
|
}
|
||||||
|
consumerName, ok1 := m["name"]
|
||||||
|
matchType, ok2 := m["matchType"]
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch fmt.Sprint(matchType) {
|
||||||
|
case "exact":
|
||||||
|
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||||
|
case "prefix":
|
||||||
|
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||||
|
case "regexp":
|
||||||
|
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||||
|
}
|
||||||
|
config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
|
||||||
|
for _, item := range json.Get("consumerResponseCheckService").Array() {
|
||||||
|
m := make(map[string]interface{})
|
||||||
|
for k, v := range item.Map() {
|
||||||
|
m[k] = v.Value()
|
||||||
|
}
|
||||||
|
consumerName, ok1 := m["name"]
|
||||||
|
matchType, ok2 := m["matchType"]
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch fmt.Sprint(matchType) {
|
||||||
|
case "exact":
|
||||||
|
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||||
|
case "prefix":
|
||||||
|
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||||
|
case "regexp":
|
||||||
|
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||||
|
}
|
||||||
|
config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
|
||||||
|
for _, item := range json.Get("consumerRiskLevel").Array() {
|
||||||
|
m := make(map[string]interface{})
|
||||||
|
for k, v := range item.Map() {
|
||||||
|
m[k] = v.Value()
|
||||||
|
}
|
||||||
|
consumerName, ok1 := m["name"]
|
||||||
|
matchType, ok2 := m["matchType"]
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch fmt.Sprint(matchType) {
|
||||||
|
case "exact":
|
||||||
|
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||||
|
case "prefix":
|
||||||
|
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||||
|
case "regexp":
|
||||||
|
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||||
|
}
|
||||||
|
config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if obj := json.Get("apiType"); obj.Exists() {
|
||||||
|
config.ApiType = obj.String()
|
||||||
|
}
|
||||||
|
if obj := json.Get("providerType"); obj.Exists() {
|
||||||
|
config.ProviderType = obj.String()
|
||||||
|
}
|
||||||
|
config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: serviceName,
|
||||||
|
Port: servicePort,
|
||||||
|
Host: serviceHost,
|
||||||
|
})
|
||||||
|
config.Metrics = make(map[string]proxywasm.MetricCounter)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) SetDefaultValues() {
|
||||||
|
switch config.Action {
|
||||||
|
case TextModerationPlus:
|
||||||
|
config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService
|
||||||
|
config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService
|
||||||
|
case MultiModalGuard:
|
||||||
|
config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService
|
||||||
|
config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService
|
||||||
|
config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService
|
||||||
|
}
|
||||||
|
config.RiskLevelBar = HighRisk
|
||||||
|
config.DenyCode = DefaultDenyCode
|
||||||
|
config.RequestContentJsonPath = DefaultRequestJsonPath
|
||||||
|
config.ResponseContentJsonPath = DefaultResponseJsonPath
|
||||||
|
config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
||||||
|
config.ContentModerationLevelBar = MaxRisk
|
||||||
|
config.PromptAttackLevelBar = MaxRisk
|
||||||
|
config.SensitiveDataLevelBar = S4Sensitive
|
||||||
|
config.ModelHallucinationLevelBar = MaxRisk
|
||||||
|
config.MaliciousUrlLevelBar = MaxRisk
|
||||||
|
config.Timeout = DefaultTimeout
|
||||||
|
config.BufferLimit = 1000
|
||||||
|
config.ApiType = ApiTextGeneration
|
||||||
|
config.ProviderType = ProviderOpenAI
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
|
||||||
|
counter, ok := config.Metrics[metricName]
|
||||||
|
if !ok {
|
||||||
|
counter = proxywasm.DefineCounterMetric(metricName)
|
||||||
|
config.Metrics[metricName] = counter
|
||||||
|
}
|
||||||
|
counter.Increment(inc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetRequestCheckService(consumer string) string {
|
||||||
|
result := config.RequestCheckService
|
||||||
|
for _, obj := range config.ConsumerRequestCheckService {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if requestCheckService, ok := obj["requestCheckService"]; ok {
|
||||||
|
result, _ = requestCheckService.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string {
|
||||||
|
result := config.RequestImageCheckService
|
||||||
|
for _, obj := range config.ConsumerRequestCheckService {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if requestCheckService, ok := obj["requestImageCheckService"]; ok {
|
||||||
|
result, _ = requestCheckService.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetResponseCheckService(consumer string) string {
|
||||||
|
result := config.ResponseCheckService
|
||||||
|
for _, obj := range config.ConsumerResponseCheckService {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if responseCheckService, ok := obj["responseCheckService"]; ok {
|
||||||
|
result, _ = responseCheckService.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string {
|
||||||
|
result := config.ResponseImageCheckService
|
||||||
|
for _, obj := range config.ConsumerResponseCheckService {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if responseCheckService, ok := obj["responseImageCheckService"]; ok {
|
||||||
|
result, _ = responseCheckService.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string {
|
||||||
|
result := config.RiskLevelBar
|
||||||
|
for _, obj := range config.ConsumerRiskLevel {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
|
||||||
|
result, _ = riskLevelBar.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string {
|
||||||
|
result := config.ContentModerationLevelBar
|
||||||
|
for _, obj := range config.ConsumerRiskLevel {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
|
||||||
|
result, _ = contentModerationLevelBar.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string {
|
||||||
|
result := config.PromptAttackLevelBar
|
||||||
|
for _, obj := range config.ConsumerRiskLevel {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
|
||||||
|
result, _ = promptAttackLevelBar.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string {
|
||||||
|
result := config.SensitiveDataLevelBar
|
||||||
|
for _, obj := range config.ConsumerRiskLevel {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
|
||||||
|
result, _ = sensitiveDataLevelBar.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string {
|
||||||
|
result := config.MaliciousUrlLevelBar
|
||||||
|
for _, obj := range config.ConsumerRiskLevel {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
|
||||||
|
result, _ = maliciousUrlLevelBar.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string {
|
||||||
|
result := config.ModelHallucinationLevelBar
|
||||||
|
for _, obj := range config.ConsumerRiskLevel {
|
||||||
|
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||||
|
if matcher.match(consumer) {
|
||||||
|
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
|
||||||
|
result, _ = modelHallucinationLevelBar.(string)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func LevelToInt(riskLevel string) int {
|
||||||
|
// First check against our defined constants
|
||||||
|
switch strings.ToLower(riskLevel) {
|
||||||
|
case MaxRisk, S4Sensitive:
|
||||||
|
return 4
|
||||||
|
case HighRisk, S3Sensitive:
|
||||||
|
return 3
|
||||||
|
case MediumRisk, S2Sensitive:
|
||||||
|
return 2
|
||||||
|
case LowRisk, S1Sensitive:
|
||||||
|
return 1
|
||||||
|
case NoRisk, NoSensitive:
|
||||||
|
return 0
|
||||||
|
default:
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
|
||||||
|
if action == MultiModalGuard || action == MultiModalGuardForBase64 {
|
||||||
|
// Check top-level risk levels for MultiModalGuard
|
||||||
|
if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Also check AttackLevel for prompt attack detection
|
||||||
|
if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check detailed results for backward compatibility
|
||||||
|
for _, detail := range data.Detail {
|
||||||
|
switch detail.Type {
|
||||||
|
case ContentModerationType:
|
||||||
|
if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case PromptAttackType:
|
||||||
|
if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case SensitiveDataType:
|
||||||
|
if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case MaliciousUrlDataType:
|
||||||
|
if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case ModelHallucinationDataType:
|
||||||
|
if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,8 +5,8 @@ go 1.24.1
|
|||||||
toolchain go1.24.4
|
toolchain go1.24.4
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
)
|
)
|
||||||
@@ -20,5 +20,6 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/tidwall/resp v0.1.1 // indirect
|
github.com/tidwall/resp v0.1.1 // indirect
|
||||||
github.com/tidwall/sjson v1.2.5 // indirect
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,8 +4,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||||
|
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
|
||||||
|
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
@@ -24,6 +28,8 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
|||||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
|||||||
@@ -0,0 +1,249 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ALGORITHM = "ACS3-HMAC-SHA256"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
httpMethod string
|
||||||
|
canonicalUri string
|
||||||
|
host string
|
||||||
|
xAcsAction string
|
||||||
|
xAcsVersion string
|
||||||
|
headers map[string]string
|
||||||
|
body []byte
|
||||||
|
queryParam map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
|
||||||
|
req := &Request{
|
||||||
|
httpMethod: httpMethod,
|
||||||
|
canonicalUri: canonicalUri,
|
||||||
|
host: host,
|
||||||
|
xAcsAction: xAcsAction,
|
||||||
|
xAcsVersion: xAcsVersion,
|
||||||
|
headers: make(map[string]string),
|
||||||
|
queryParam: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
req.headers["host"] = host
|
||||||
|
req.headers["x-acs-action"] = xAcsAction
|
||||||
|
req.headers["x-acs-version"] = xAcsVersion
|
||||||
|
req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339)
|
||||||
|
req.headers["x-acs-signature-nonce"] = uuid.New().String()
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
|
||||||
|
newQueryParams := make(map[string]interface{})
|
||||||
|
processObject(newQueryParams, "", req.queryParam)
|
||||||
|
req.queryParam = newQueryParams
|
||||||
|
canonicalQueryString := ""
|
||||||
|
keys := maps.Keys(req.queryParam)
|
||||||
|
sort.Strings(keys)
|
||||||
|
for _, k := range keys {
|
||||||
|
v := req.queryParam[k]
|
||||||
|
canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&"
|
||||||
|
}
|
||||||
|
canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&")
|
||||||
|
|
||||||
|
var bodyContent []byte
|
||||||
|
if req.body == nil {
|
||||||
|
bodyContent = []byte("")
|
||||||
|
} else {
|
||||||
|
bodyContent = req.body
|
||||||
|
}
|
||||||
|
hashedRequestPayload := sha256Hex(bodyContent)
|
||||||
|
req.headers["x-acs-content-sha256"] = hashedRequestPayload
|
||||||
|
|
||||||
|
if SecurityToken != "" {
|
||||||
|
req.headers["x-acs-security-token"] = SecurityToken
|
||||||
|
}
|
||||||
|
|
||||||
|
canonicalHeaders := ""
|
||||||
|
signedHeaders := ""
|
||||||
|
HeadersKeys := maps.Keys(req.headers)
|
||||||
|
sort.Strings(HeadersKeys)
|
||||||
|
for _, k := range HeadersKeys {
|
||||||
|
lowerKey := strings.ToLower(k)
|
||||||
|
if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
|
||||||
|
canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
|
||||||
|
signedHeaders += lowerKey + ";"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
signedHeaders = strings.TrimSuffix(signedHeaders, ";")
|
||||||
|
|
||||||
|
canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
|
||||||
|
|
||||||
|
hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest))
|
||||||
|
stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest
|
||||||
|
|
||||||
|
byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
signature := strings.ToLower(hex.EncodeToString(byteData))
|
||||||
|
|
||||||
|
authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
|
||||||
|
req.headers["Authorization"] = authorization
|
||||||
|
}
|
||||||
|
|
||||||
|
func hmac256(key []byte, toSignString string) ([]byte, error) {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
_, err := h.Write([]byte(toSignString))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h.Sum(nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sha256Hex(byteArray []byte) string {
|
||||||
|
hash := sha256.New()
|
||||||
|
_, _ = hash.Write(byteArray)
|
||||||
|
hexString := hex.EncodeToString(hash.Sum(nil))
|
||||||
|
return hexString
|
||||||
|
}
|
||||||
|
|
||||||
|
func percentCode(str string) string {
|
||||||
|
str = strings.ReplaceAll(str, "+", "%20")
|
||||||
|
str = strings.ReplaceAll(str, "*", "%2A")
|
||||||
|
str = strings.ReplaceAll(str, "%7E", "~")
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func formDataToString(formData map[string]interface{}) *string {
|
||||||
|
tmp := make(map[string]interface{})
|
||||||
|
processObject(tmp, "", formData)
|
||||||
|
res := ""
|
||||||
|
urlEncoder := url.Values{}
|
||||||
|
for key, value := range tmp {
|
||||||
|
v := fmt.Sprintf("%v", value)
|
||||||
|
urlEncoder.Add(key, v)
|
||||||
|
}
|
||||||
|
res = urlEncoder.Encode()
|
||||||
|
return &res
|
||||||
|
}
|
||||||
|
|
||||||
|
// processObject 递归处理对象,将复杂对象(如Map和List)展开为平面的键值对
|
||||||
|
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
|
||||||
|
if value == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
for i, item := range v {
|
||||||
|
processObject(mapResult, fmt.Sprintf("%s.%d", key, i+1), item)
|
||||||
|
}
|
||||||
|
case map[string]interface{}:
|
||||||
|
for subKey, subValue := range v {
|
||||||
|
processObject(mapResult, fmt.Sprintf("%s.%s", key, subKey), subValue)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if strings.HasPrefix(key, ".") {
|
||||||
|
key = key[1:]
|
||||||
|
}
|
||||||
|
if b, ok := v.([]byte); ok {
|
||||||
|
mapResult[key] = string(b)
|
||||||
|
} else {
|
||||||
|
mapResult[key] = fmt.Sprintf("%v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
|
||||||
|
httpMethod := "POST"
|
||||||
|
canonicalUri := "/"
|
||||||
|
xAcsVersion := "2022-03-02"
|
||||||
|
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||||
|
|
||||||
|
req.queryParam["Service"] = checkService
|
||||||
|
|
||||||
|
body := make(map[string]interface{})
|
||||||
|
serviceParameters := make(map[string]interface{})
|
||||||
|
serviceParameters["content"] = text
|
||||||
|
serviceParameters["sessionId"] = sessionID
|
||||||
|
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||||
|
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||||
|
body["ServiceParameters"] = serviceParametersJSON
|
||||||
|
str := formDataToString(body)
|
||||||
|
req.body = []byte(*str)
|
||||||
|
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||||
|
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||||
|
|
||||||
|
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
keys := maps.Keys(req.queryParam)
|
||||||
|
sort.Strings(keys)
|
||||||
|
for _, k := range keys {
|
||||||
|
v := req.queryParam[k]
|
||||||
|
q.Set(k, fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
for k, v := range req.headers {
|
||||||
|
if k != "host" {
|
||||||
|
headers = append(headers, [2]string{k, v})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "?" + q.Encode(), headers, req.body
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
|
||||||
|
httpMethod := "POST"
|
||||||
|
canonicalUri := "/"
|
||||||
|
xAcsVersion := "2022-03-02"
|
||||||
|
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||||
|
|
||||||
|
req.queryParam["Service"] = checkService
|
||||||
|
|
||||||
|
body := make(map[string]interface{})
|
||||||
|
serviceParameters := make(map[string]interface{})
|
||||||
|
if imgUrl != "" {
|
||||||
|
serviceParameters["imageUrls"] = []string{imgUrl}
|
||||||
|
}
|
||||||
|
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||||
|
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||||
|
body["ServiceParameters"] = serviceParametersJSON
|
||||||
|
if imgBase64 != "" {
|
||||||
|
body["ImageBase64Str"] = imgBase64
|
||||||
|
}
|
||||||
|
str := formDataToString(body)
|
||||||
|
req.body = []byte(*str)
|
||||||
|
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||||
|
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||||
|
|
||||||
|
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
keys := maps.Keys(req.queryParam)
|
||||||
|
sort.Strings(keys)
|
||||||
|
for _, k := range keys {
|
||||||
|
v := req.queryParam[k]
|
||||||
|
q.Set(k, fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
for k, v := range req.headers {
|
||||||
|
// host will be added by envoy automatically
|
||||||
|
if k != "host" {
|
||||||
|
headers = append(headers, [2]string{k, v})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "?" + q.Encode(), headers, req.body
|
||||||
|
}
|
||||||
@@ -0,0 +1,249 @@
|
|||||||
|
package text
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||||
|
ctx.SetContext("end_of_stream_received", false)
|
||||||
|
ctx.SetContext("during_call", false)
|
||||||
|
ctx.SetContext("risk_detected", false)
|
||||||
|
sessionID, _ := utils.GenerateHexID(20)
|
||||||
|
ctx.SetContext("sessionID", sessionID)
|
||||||
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
|
ctx.NeedPauseStreamingResponse()
|
||||||
|
return types.ActionContinue
|
||||||
|
} else {
|
||||||
|
ctx.BufferResponseBody()
|
||||||
|
return types.HeaderStopIteration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
var sessionID string
|
||||||
|
if ctx.GetContext("sessionID") == nil {
|
||||||
|
sessionID, _ = utils.GenerateHexID(20)
|
||||||
|
ctx.SetContext("sessionID", sessionID)
|
||||||
|
} else {
|
||||||
|
sessionID, _ = ctx.GetContext("sessionID").(string)
|
||||||
|
}
|
||||||
|
var bufferQueue [][]byte
|
||||||
|
var singleCall func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
ctx.SetContext("during_call", false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||||
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
ctx.SetContext("during_call", false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = "\n" + response.Data.Advice[0].Answer
|
||||||
|
} else if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
|
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
|
||||||
|
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
|
||||||
|
bufferQueue = [][]byte{}
|
||||||
|
if !endStream {
|
||||||
|
ctx.SetContext("during_call", false)
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
if ctx.GetContext("during_call").(bool) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
|
||||||
|
var buffer string
|
||||||
|
for ctx.BufferQueueSize() > 0 {
|
||||||
|
front := ctx.PopBuffer()
|
||||||
|
bufferQueue = append(bufferQueue, front)
|
||||||
|
msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
|
||||||
|
buffer += msg
|
||||||
|
if len([]rune(buffer)) >= config.BufferLimit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if streaming body has reasoning_content, buffer maybe empty
|
||||||
|
log.Debugf("current content piece: %s", buffer)
|
||||||
|
if len(buffer) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx.SetContext("during_call", true)
|
||||||
|
log.Debugf("current content piece: %s", buffer)
|
||||||
|
checkService := config.GetResponseCheckService(consumer)
|
||||||
|
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ctx.GetContext("risk_detected").(bool) {
|
||||||
|
unifiedChunk := wrapper.UnifySSEChunk(data)
|
||||||
|
hasTrailingSeparator := bytes.HasSuffix(unifiedChunk, []byte("\n\n"))
|
||||||
|
trimmedChunk := bytes.TrimSpace(unifiedChunk)
|
||||||
|
chunks := bytes.Split(trimmedChunk, []byte("\n\n"))
|
||||||
|
// Filter out empty chunks
|
||||||
|
nonEmptyChunks := make([][]byte, 0, len(chunks))
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
if len(chunk) > 0 {
|
||||||
|
nonEmptyChunks = append(nonEmptyChunks, chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Restore separators
|
||||||
|
for i := range len(nonEmptyChunks) - 1 {
|
||||||
|
nonEmptyChunks[i] = append(nonEmptyChunks[i], []byte("\n\n")...)
|
||||||
|
}
|
||||||
|
if hasTrailingSeparator && len(nonEmptyChunks) > 0 {
|
||||||
|
nonEmptyChunks[len(nonEmptyChunks)-1] = append(nonEmptyChunks[len(nonEmptyChunks)-1], []byte("\n\n")...)
|
||||||
|
}
|
||||||
|
for _, chunk := range nonEmptyChunks {
|
||||||
|
ctx.PushBuffer(chunk)
|
||||||
|
}
|
||||||
|
// for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
|
||||||
|
// ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
|
||||||
|
// }
|
||||||
|
ctx.SetContext("end_of_stream_received", endOfStream)
|
||||||
|
if !ctx.GetContext("during_call").(bool) {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
} else if endOfStream {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
log.Debugf("checking response body...")
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||||
|
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
||||||
|
var content string
|
||||||
|
if isStreamingResponse {
|
||||||
|
content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
|
||||||
|
} else {
|
||||||
|
content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
|
||||||
|
}
|
||||||
|
log.Debugf("Raw response content is: %s", content)
|
||||||
|
if len(content) == 0 {
|
||||||
|
log.Info("response content is empty. skip")
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
contentIndex := 0
|
||||||
|
sessionID, _ := utils.GenerateHexID(20)
|
||||||
|
var singleCall func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if contentIndex >= len(content) {
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "response pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
if config.ProtocolOriginal {
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
} else if isStreamingResponse {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
|
} else {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
}
|
||||||
|
config.IncrementCounter("ai_sec_response_deny", 1)
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "response deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
var nextContentIndex int
|
||||||
|
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||||
|
nextContentIndex = len(content)
|
||||||
|
} else {
|
||||||
|
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||||
|
}
|
||||||
|
contentPiece := content[contentIndex:nextContentIndex]
|
||||||
|
contentIndex = nextContentIndex
|
||||||
|
log.Debugf("current content piece: %s", contentPiece)
|
||||||
|
checkService := config.GetResponseCheckService(consumer)
|
||||||
|
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
singleCall()
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
package multi_modal_guard
|
||||||
|
|
||||||
|
import (
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return common_text.HandleTextGenerationResponseHeader(ctx, config)
|
||||||
|
case cfg.ApiImageGeneration:
|
||||||
|
switch config.ProviderType {
|
||||||
|
case cfg.ProviderOpenAI, cfg.ProviderQwen:
|
||||||
|
return image.HandleImageGenerationResponseHeader(ctx, config)
|
||||||
|
default:
|
||||||
|
log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||||
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||||
|
default:
|
||||||
|
log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
|
||||||
|
case cfg.ApiImageGeneration:
|
||||||
|
switch config.ProviderType {
|
||||||
|
case cfg.ProviderOpenAI:
|
||||||
|
return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body)
|
||||||
|
case cfg.ProviderQwen:
|
||||||
|
return image.HandleQwenImageGenerationResponseBody(ctx, config, body)
|
||||||
|
default:
|
||||||
|
log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||||
|
ctx.SetContext("risk_detected", false)
|
||||||
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
return types.ActionContinue
|
||||||
|
} else {
|
||||||
|
ctx.BufferResponseBody()
|
||||||
|
return types.HeaderStopIteration
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageItemForOpenAI struct {
|
||||||
|
Content string
|
||||||
|
Type string // URL or BASE64
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOpenAIImageResults(body []byte) []ImageItemForOpenAI {
|
||||||
|
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||||
|
result := []ImageItemForOpenAI{}
|
||||||
|
for _, part := range gjson.GetBytes(body, "data").Array() {
|
||||||
|
if url := part.Get("url").String(); url != "" {
|
||||||
|
result = append(result, ImageItemForOpenAI{
|
||||||
|
Content: url,
|
||||||
|
Type: "URL",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if b64 := part.Get("b64_json").String(); b64 != "" {
|
||||||
|
result = append(result, ImageItemForOpenAI{
|
||||||
|
Content: b64,
|
||||||
|
Type: "BASE64",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
log.Debugf("checking response body...")
|
||||||
|
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
imgResults := getOpenAIImageResults(body)
|
||||||
|
if len(imgResults) == 0 {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
imageIndex := 0
|
||||||
|
var singleCall func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
imageIndex += 1
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
if imageIndex < len(imgResults) {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
if imageIndex < len(imgResults) {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if imageIndex >= len(imgResults) {
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
img := imgResults[imageIndex]
|
||||||
|
imgUrl := ""
|
||||||
|
imgBase64 := ""
|
||||||
|
if img.Type == "BASE64" {
|
||||||
|
imgBase64 = img.Content
|
||||||
|
} else {
|
||||||
|
imgUrl = img.Content
|
||||||
|
}
|
||||||
|
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
singleCall()
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getQwenImageUrls(body []byte) []string {
|
||||||
|
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||||
|
result := []string{}
|
||||||
|
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
|
||||||
|
// 虚拟模特/图像背景生成/人物写真FaceChain/文生图StableDiffusion/文生图FLUX/文字纹理生成API
|
||||||
|
for _, part := range gjson.GetBytes(body, "output.results").Array() {
|
||||||
|
if url := part.Get("url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 图像编辑
|
||||||
|
for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() {
|
||||||
|
if url := part.Get("image").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 图像翻译/AI试衣OutfitAnyone
|
||||||
|
if url := gjson.GetBytes(body, "output.image_url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
// 图像画面扩展/(part of)人物实例分割/图像擦除补全
|
||||||
|
if url := gjson.GetBytes(body, "output.output_image_url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
// 鞋靴模特
|
||||||
|
if url := gjson.GetBytes(body, "output.result_url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
// 创意海报生成
|
||||||
|
for _, part := range gjson.GetBytes(body, "output.render_urls").Array() {
|
||||||
|
if url := part.String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() {
|
||||||
|
if url := part.String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 人物实例分割
|
||||||
|
if url := gjson.GetBytes(body, "output.output_vis_image_url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
// 文字变形API
|
||||||
|
for _, part := range gjson.GetBytes(body, "output.results").Array() {
|
||||||
|
if url := part.Get("png_url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
if url := part.Get("svg_url").String(); url != "" {
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
log.Debugf("checking response body...")
|
||||||
|
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
imgUrls := getQwenImageUrls(body)
|
||||||
|
if len(imgUrls) == 0 {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
imageIndex := 0
|
||||||
|
var singleCall func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
imageIndex += 1
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
if imageIndex < len(imgUrls) {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
if imageIndex < len(imgUrls) {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if imageIndex >= len(imgUrls) {
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
imgUrl := imgUrls[imageIndex]
|
||||||
|
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
singleCall()
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
@@ -0,0 +1,231 @@
|
|||||||
|
package text
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageItemForOpenAI struct {
|
||||||
|
Content string
|
||||||
|
Type string // URL or BASE64
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseContent(json gjson.Result) (text string, images []ImageItemForOpenAI) {
|
||||||
|
images = []ImageItemForOpenAI{}
|
||||||
|
if json.IsArray() {
|
||||||
|
for _, item := range json.Array() {
|
||||||
|
switch item.Get("type").String() {
|
||||||
|
case "text":
|
||||||
|
text += item.Get("text").String()
|
||||||
|
case "image_url":
|
||||||
|
imgContent := item.Get("image_url.url").String()
|
||||||
|
if strings.HasPrefix(imgContent, "data:image") {
|
||||||
|
images = append(images, ImageItemForOpenAI{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "BASE64",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
images = append(images, ImageItemForOpenAI{
|
||||||
|
Content: imgContent,
|
||||||
|
Type: "URL",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
text = json.String()
|
||||||
|
}
|
||||||
|
return text, images
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
checkService := config.GetRequestCheckService(consumer)
|
||||||
|
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||||
|
content, images := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath))
|
||||||
|
log.Debugf("Raw request content is: %s", content)
|
||||||
|
if len(content) == 0 && len(images) == 0 {
|
||||||
|
log.Info("request content is empty. skip")
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
contentIndex := 0
|
||||||
|
imageIndex := 0
|
||||||
|
sessionID, _ := utils.GenerateHexID(20)
|
||||||
|
var singleCall func()
|
||||||
|
var singleCallForImage func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if contentIndex >= len(content) {
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
if len(images) > 0 && config.CheckRequestImage {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
if config.ProtocolOriginal {
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
|
} else {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
}
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
var nextContentIndex int
|
||||||
|
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||||
|
nextContentIndex = len(content)
|
||||||
|
} else {
|
||||||
|
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||||
|
}
|
||||||
|
contentPiece := content[contentIndex:nextContentIndex]
|
||||||
|
contentIndex = nextContentIndex
|
||||||
|
log.Debugf("current content piece: %s", contentPiece)
|
||||||
|
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
imageIndex += 1
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
if imageIndex < len(images) {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%+v", err)
|
||||||
|
if imageIndex < len(images) {
|
||||||
|
singleCallForImage()
|
||||||
|
} else {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if imageIndex >= len(images) {
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
} else {
|
||||||
|
singleCallForImage()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
if config.ProtocolOriginal {
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
|
} else {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
}
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCallForImage = func() {
|
||||||
|
img := images[imageIndex]
|
||||||
|
imgUrl := ""
|
||||||
|
imgBase64 := ""
|
||||||
|
if img.Type == "BASE64" {
|
||||||
|
imgBase64 = img.Content
|
||||||
|
} else {
|
||||||
|
imgUrl = img.Content
|
||||||
|
}
|
||||||
|
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||||
|
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(content) > 0 {
|
||||||
|
singleCall()
|
||||||
|
} else {
|
||||||
|
singleCallForImage()
|
||||||
|
}
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package text_moderation_plus
|
||||||
|
|
||||||
|
import (
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||||
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return common_text.HandleTextGenerationResponseHeader(ctx, config)
|
||||||
|
default:
|
||||||
|
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||||
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||||
|
default:
|
||||||
|
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
switch config.ApiType {
|
||||||
|
case cfg.ApiTextGeneration:
|
||||||
|
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
|
||||||
|
default:
|
||||||
|
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package text
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||||
|
consumer, _ := ctx.GetContext("consumer").(string)
|
||||||
|
startTime := time.Now().UnixMilli()
|
||||||
|
content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||||
|
log.Debugf("Raw request content is: %s", content)
|
||||||
|
if len(content) == 0 {
|
||||||
|
log.Info("request content is empty. skip")
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
contentIndex := 0
|
||||||
|
sessionID, _ := utils.GenerateHexID(20)
|
||||||
|
var singleCall func()
|
||||||
|
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Info(string(responseBody))
|
||||||
|
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var response cfg.Response
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to unmarshal aliyun content security response at request phase")
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||||
|
if contentIndex >= len(content) {
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
} else {
|
||||||
|
singleCall()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
denyMessage := cfg.DefaultDenyMessage
|
||||||
|
if config.DenyMessage != "" {
|
||||||
|
denyMessage = config.DenyMessage
|
||||||
|
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||||
|
denyMessage = response.Data.Advice[0].Answer
|
||||||
|
}
|
||||||
|
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||||
|
if config.ProtocolOriginal {
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
|
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
|
} else {
|
||||||
|
randomID := utils.GenerateRandomChatID()
|
||||||
|
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||||
|
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
|
}
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||||
|
endTime := time.Now().UnixMilli()
|
||||||
|
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||||
|
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||||
|
if response.Data.Advice != nil {
|
||||||
|
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||||
|
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||||
|
}
|
||||||
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
}
|
||||||
|
singleCall = func() {
|
||||||
|
var nextContentIndex int
|
||||||
|
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||||
|
nextContentIndex = len(content)
|
||||||
|
} else {
|
||||||
|
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||||
|
}
|
||||||
|
contentPiece := content[contentIndex:nextContentIndex]
|
||||||
|
contentIndex = nextContentIndex
|
||||||
|
checkService := config.GetRequestCheckService(consumer)
|
||||||
|
path, headers, body := common.GenerateRequestForText(config, cfg.TextModerationPlus, checkService, contentPiece, sessionID)
|
||||||
|
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed call the safe check service: %v", err)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
singleCall()
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/test"
|
"github.com/higress-group/wasm-go/pkg/test"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, config)
|
require.NotNil(t, config)
|
||||||
|
|
||||||
securityConfig := config.(*AISecurityConfig)
|
securityConfig := config.(*cfg.AISecurityConfig)
|
||||||
require.Equal(t, "test-ak", securityConfig.ak)
|
require.Equal(t, "test-ak", securityConfig.AK)
|
||||||
require.Equal(t, "test-sk", securityConfig.sk)
|
require.Equal(t, "test-sk", securityConfig.SK)
|
||||||
require.Equal(t, true, securityConfig.checkRequest)
|
require.Equal(t, true, securityConfig.CheckRequest)
|
||||||
require.Equal(t, true, securityConfig.checkResponse)
|
require.Equal(t, true, securityConfig.CheckResponse)
|
||||||
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
|
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||||||
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
|
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||||||
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
|
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||||||
require.Equal(t, uint32(2000), securityConfig.timeout)
|
require.Equal(t, uint32(2000), securityConfig.Timeout)
|
||||||
require.Equal(t, 1000, securityConfig.bufferLimit)
|
require.Equal(t, 1000, securityConfig.BufferLimit)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 测试仅检查请求的配置
|
// 测试仅检查请求的配置
|
||||||
@@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, config)
|
require.NotNil(t, config)
|
||||||
|
|
||||||
securityConfig := config.(*AISecurityConfig)
|
securityConfig := config.(*cfg.AISecurityConfig)
|
||||||
require.Equal(t, true, securityConfig.checkRequest)
|
require.Equal(t, true, securityConfig.CheckRequest)
|
||||||
require.Equal(t, false, securityConfig.checkResponse)
|
require.Equal(t, false, securityConfig.CheckResponse)
|
||||||
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
|
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
|
||||||
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
|
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
|
||||||
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
|
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 测试缺少必需字段的配置
|
// 测试缺少必需字段的配置
|
||||||
@@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, config)
|
require.NotNil(t, config)
|
||||||
|
|
||||||
securityConfig := config.(*AISecurityConfig)
|
securityConfig := config.(*cfg.AISecurityConfig)
|
||||||
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
|
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
|
||||||
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
|
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
|
||||||
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
|
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
|
||||||
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
|
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
|
||||||
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
|
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
|
||||||
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
|
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -385,62 +387,27 @@ func TestOnHttpResponseHeaders(t *testing.T) {
|
|||||||
func TestRiskLevelFunctions(t *testing.T) {
|
func TestRiskLevelFunctions(t *testing.T) {
|
||||||
// 测试风险等级转换函数
|
// 测试风险等级转换函数
|
||||||
t.Run("risk level conversion", func(t *testing.T) {
|
t.Run("risk level conversion", func(t *testing.T) {
|
||||||
require.Equal(t, 4, levelToInt(MaxRisk))
|
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
|
||||||
require.Equal(t, 3, levelToInt(HighRisk))
|
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
|
||||||
require.Equal(t, 2, levelToInt(MediumRisk))
|
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
|
||||||
require.Equal(t, 1, levelToInt(LowRisk))
|
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
|
||||||
require.Equal(t, 0, levelToInt(NoRisk))
|
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
|
||||||
require.Equal(t, -1, levelToInt("invalid"))
|
require.Equal(t, -1, cfg.LevelToInt("invalid"))
|
||||||
})
|
})
|
||||||
|
|
||||||
// 测试风险等级比较
|
// 测试风险等级比较
|
||||||
t.Run("risk level comparison", func(t *testing.T) {
|
t.Run("risk level comparison", func(t *testing.T) {
|
||||||
require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk))
|
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
|
||||||
require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk))
|
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
|
||||||
require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk))
|
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
|
||||||
require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk))
|
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUtilityFunctions(t *testing.T) {
|
func TestUtilityFunctions(t *testing.T) {
|
||||||
// 测试URL编码函数
|
|
||||||
t.Run("url encoding", func(t *testing.T) {
|
|
||||||
original := "test+string:with=special&chars@$"
|
|
||||||
encoded := urlEncoding(original)
|
|
||||||
require.NotEqual(t, original, encoded)
|
|
||||||
require.Contains(t, encoded, "%2B") // + 应该被编码
|
|
||||||
require.Contains(t, encoded, "%3A") // : 应该被编码
|
|
||||||
require.Contains(t, encoded, "%3D") // = 应该被编码
|
|
||||||
require.Contains(t, encoded, "%26") // & 应该被编码
|
|
||||||
})
|
|
||||||
|
|
||||||
// 测试HMAC-SHA1签名函数
|
|
||||||
t.Run("hmac sha1", func(t *testing.T) {
|
|
||||||
message := "test message"
|
|
||||||
secret := "test secret"
|
|
||||||
signature := hmacSha1(message, secret)
|
|
||||||
require.NotEmpty(t, signature)
|
|
||||||
require.NotEqual(t, message, signature)
|
|
||||||
})
|
|
||||||
|
|
||||||
// 测试签名生成函数
|
|
||||||
t.Run("signature generation", func(t *testing.T) {
|
|
||||||
host, status := test.NewTestHost(basicConfig)
|
|
||||||
defer host.Reset()
|
|
||||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
||||||
|
|
||||||
params := map[string]string{
|
|
||||||
"key1": "value1",
|
|
||||||
"key2": "value2",
|
|
||||||
}
|
|
||||||
secret := "test-secret"
|
|
||||||
signature := getSign(params, secret)
|
|
||||||
require.NotEmpty(t, signature)
|
|
||||||
})
|
|
||||||
|
|
||||||
// 测试十六进制ID生成函数
|
// 测试十六进制ID生成函数
|
||||||
t.Run("hex id generation", func(t *testing.T) {
|
t.Run("hex id generation", func(t *testing.T) {
|
||||||
id, err := generateHexID(16)
|
id, err := utils.GenerateHexID(16)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, id, 16)
|
require.Len(t, id, 16)
|
||||||
require.Regexp(t, "^[0-9a-f]+$", id)
|
require.Regexp(t, "^[0-9a-f]+$", id)
|
||||||
@@ -448,7 +415,7 @@ func TestUtilityFunctions(t *testing.T) {
|
|||||||
|
|
||||||
// 测试随机ID生成函数
|
// 测试随机ID生成函数
|
||||||
t.Run("random id generation", func(t *testing.T) {
|
t.Run("random id generation", func(t *testing.T) {
|
||||||
id := generateRandomID()
|
id := utils.GenerateRandomChatID()
|
||||||
require.NotEmpty(t, id)
|
require.NotEmpty(t, id)
|
||||||
require.Contains(t, id, "chatcmpl-")
|
require.Contains(t, id, "chatcmpl-")
|
||||||
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
|
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
|
||||||
|
|||||||
43
plugins/wasm-go/extensions/ai-security-guard/utils/utils.go
Normal file
43
plugins/wasm-go/extensions/ai-security-guard/utils/utils.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
mrand "math/rand"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateHexID(length int) (string, error) {
|
||||||
|
bytes := make([]byte, length/2)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateRandomChatID() string {
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
b := make([]byte, 29)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[mrand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
return "chatcmpl-" + string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||||
|
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
|
||||||
|
strChunks := []string{}
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
|
||||||
|
strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String())
|
||||||
|
}
|
||||||
|
return strings.Join(strChunks, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConsumer(ctx wrapper.HttpContext) string {
|
||||||
|
return ctx.GetStringContext("consumer", "")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user