From ce66ff68ce1a29e55616d8ee328bcf60f0215f88 Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Thu, 5 Dec 2024 13:39:20 +0800 Subject: [PATCH] solve aliyun lvwang content length limit problem (#1569) --- .../extensions/ai-security-guard/go.mod | 2 +- .../extensions/ai-security-guard/go.sum | 4 +- .../extensions/ai-security-guard/main.go | 295 ++++++++++-------- 3 files changed, 171 insertions(+), 130 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.mod b/plugins/wasm-go/extensions/ai-security-guard/go.mod index 3ab8ec183..c21172194 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.mod +++ b/plugins/wasm-go/extensions/ai-security-guard/go.mod @@ -6,7 +6,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f + github.com/higress-group/proxy-wasm-go-sdk v1.0.0 github.com/tidwall/gjson v1.17.3 ) diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.sum b/plugins/wasm-go/extensions/ai-security-guard/go.sum index 042eae70f..b4ab172fe 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.sum +++ b/plugins/wasm-go/extensions/ai-security-guard/go.sum @@ -3,8 +3,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index cf7cdead0..02660fad8 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -55,6 +55,7 @@ const ( DefaultDenyMessage = "很抱歉,我无法回答您的问题" AliyunUserAgent = "CIPFrom/AIGateway" + LengthLimit = 1800 ) type Response struct { @@ -260,73 +261,92 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] log.Info("request content is empty. skip") return types.ActionContinue } - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": "TextModerationPlus", - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": config.requestCheckService, - "ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)), - } - if config.token != "" { - params["SecurityToken"] = config.token - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Info(string(responseBody)) - if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + contentIndex := 0 + sessionID, _ := generateHexID(20) + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpRequest() + return + } + var response Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at request phase") + proxywasm.ResumeHttpRequest() + return + } + if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { + if contentIndex >= len(content) { proxywasm.ResumeHttpRequest() - return - } - var response Response - err := json.Unmarshal(responseBody, &response) - if err != nil { - log.Error("failed to unmarshal aliyun content security response at request phase") - proxywasm.ResumeHttpRequest() - return - } - if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { - proxywasm.ResumeHttpRequest() - return - } - denyMessage := DefaultDenyMessage - if config.denyMessage != "" { - denyMessage = config.denyMessage - } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { - denyMessage = response.Data.Advice[0].Answer - } - marshalledDenyMessage := marshalStr(denyMessage, log) - if config.protocolOriginal { - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) } else { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + singleCall() } - ctx.DontReadResponseBody() - config.incrementCounter("ai_sec_request_deny", 1) - }, - ) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - return types.ActionContinue + return + } + denyMessage := DefaultDenyMessage + if config.denyMessage != "" { + denyMessage = config.denyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := marshalStr(denyMessage, log) + if config.protocolOriginal { + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + randomID := generateRandomID() + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } + ctx.DontReadResponseBody() + config.incrementCounter("ai_sec_request_deny", 1) + proxywasm.ResumeHttpRequest() } + singleCall = func() { + timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") + randomID, _ := generateHexID(16) + var nextContentIndex int + if contentIndex+LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + log.Debugf("current content piece: %s", contentPiece) + params := map[string]string{ + "Format": "JSON", + "Version": "2022-03-02", + "SignatureMethod": "Hmac-SHA1", + "SignatureNonce": randomID, + "SignatureVersion": "1.0", + "Action": "TextModerationPlus", + "AccessKeyId": config.ak, + "Timestamp": timestamp, + "Service": config.requestCheckService, + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece, log)), + } + if config.token != "" { + params["SecurityToken"] = config.token + } + signature := getSign(params, config.sk+"&") + reqParams := url.Values{} + for k, v := range params { + reqParams.Add(k, v) + } + reqParams.Add("Signature", signature) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpRequest() + } + } + singleCall() return types.ActionPause } @@ -385,73 +405,94 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ log.Info("response content is empty. skip") return types.ActionContinue } - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": "TextModerationPlus", - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": config.responseCheckService, - "ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)), - } - if config.token != "" { - params["SecurityToken"] = config.token - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - defer proxywasm.ResumeHttpResponse() - log.Info(string(responseBody)) - if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { - return - } - var response Response - err := json.Unmarshal(responseBody, &response) - if err != nil { - log.Error("failed to unmarshal aliyun content security response at response phase") - return - } - if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { - return - } - denyMessage := DefaultDenyMessage - if config.denyMessage != "" { - denyMessage = config.denyMessage - } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { - denyMessage = response.Data.Advice[0].Answer - } - marshalledDenyMessage := marshalStr(denyMessage, log) - var jsonData []byte - if config.protocolOriginal { - jsonData = []byte(marshalledDenyMessage) - } else if isStreamingResponse { - randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + contentIndex := 0 + sessionID, _ := generateHexID(20) + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpResponse() + return + } + var response Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at response phase") + proxywasm.ResumeHttpResponse() + return + } + if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { + if contentIndex >= len(content) { + proxywasm.ResumeHttpResponse() } else { - randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + singleCall() } - delete(hdsMap, "content-length") - hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)} - proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) - proxywasm.ReplaceHttpResponseBody(jsonData) - config.incrementCounter("ai_sec_response_deny", 1) - }, - ) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - return types.ActionContinue + return + } + denyMessage := DefaultDenyMessage + if config.denyMessage != "" { + denyMessage = config.denyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := marshalStr(denyMessage, log) + var jsonData []byte + if config.protocolOriginal { + jsonData = []byte(marshalledDenyMessage) + } else if isStreamingResponse { + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + } else { + randomID := generateRandomID() + jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + } + delete(hdsMap, "content-length") + hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)} + proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) + proxywasm.ReplaceHttpResponseBody(jsonData) + config.incrementCounter("ai_sec_response_deny", 1) + proxywasm.ResumeHttpResponse() } + singleCall = func() { + timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") + randomID, _ := generateHexID(16) + var nextContentIndex int + if contentIndex+LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + log.Debugf("current content piece: %s", contentPiece) + params := map[string]string{ + "Format": "JSON", + "Version": "2022-03-02", + "SignatureMethod": "Hmac-SHA1", + "SignatureNonce": randomID, + "SignatureVersion": "1.0", + "Action": "TextModerationPlus", + "AccessKeyId": config.ak, + "Timestamp": timestamp, + "Service": config.responseCheckService, + "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece, log)), + } + if config.token != "" { + params["SecurityToken"] = config.token + } + signature := getSign(params, config.sk+"&") + reqParams := url.Values{} + for k, v := range params { + reqParams.Add(k, v) + } + reqParams.Add("Signature", signature) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpResponse() + } + } + singleCall() return types.ActionPause }