mirror of
https://github.com/alibaba/higress.git
synced 2026-06-05 18:57:30 +08:00
feat(ai-security-guard): add Embedding API content detection support (#3895)
Signed-off-by: root <jianwei.wjw@alibaba-inc.com>
This commit is contained in:
@@ -118,6 +118,7 @@ const (
|
||||
ApiTextGeneration = "text_generation"
|
||||
ApiImageGeneration = "image_generation"
|
||||
ApiMCP = "mcp"
|
||||
ApiEmbedding = "embedding"
|
||||
)
|
||||
|
||||
// provider types
|
||||
@@ -206,6 +207,7 @@ type AISecurityConfig struct {
|
||||
ResponseStreamContentJsonPath string
|
||||
ResponseContentFallbackJsonPaths []string
|
||||
ResponseStreamContentFallbackJsonPaths []string
|
||||
ResponseErrorContentJsonPath string
|
||||
DenyCode int64
|
||||
DenyMessage string
|
||||
ProtocolOriginal bool
|
||||
@@ -223,7 +225,7 @@ type AISecurityConfig struct {
|
||||
ConsumerRequestCheckService []map[string]interface{}
|
||||
ConsumerResponseCheckService []map[string]interface{}
|
||||
ConsumerRiskLevel []map[string]interface{}
|
||||
// text_generation, image_generation, etc.
|
||||
// text_generation, image_generation, embedding, etc.
|
||||
ApiType string
|
||||
// openai, qwen, comfyui, etc.
|
||||
ProviderType string
|
||||
@@ -355,6 +357,9 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error {
|
||||
} else if exists {
|
||||
config.ResponseStreamContentFallbackJsonPaths = paths
|
||||
}
|
||||
if obj := json.Get("responseErrorContentJsonPath"); obj.Exists() {
|
||||
config.ResponseErrorContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
|
||||
config.ContentModerationLevelBar = obj.String()
|
||||
if LevelToInt(config.ContentModerationLevelBar) <= 0 {
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
package embedding
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"

	cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/log"
	"github.com/higress-group/wasm-go/pkg/wrapper"
	"github.com/tidwall/gjson"
)
|
||||
|
||||
// OpenAI Embedding error response format
|
||||
const EmbeddingErrorResponseFormat = `{"error": {"message": "%s", "type": "invalid_request_error", "param": null, "code": "content_policy_violation"}}`
|
||||
|
||||
// parseInput extracts text from the input field of an Embedding request.
|
||||
// input can be:
|
||||
// - A string: returns the string directly
|
||||
// - An array of strings: returns all strings joined
|
||||
// - An array of integers (token IDs): returns empty with unsupportedType=true
|
||||
func parseInput(json gjson.Result) (text string, unsupportedType bool) {
|
||||
if json.IsArray() {
|
||||
// Check if it's an array of strings or token IDs
|
||||
arr := json.Array()
|
||||
if len(arr) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Check first element type
|
||||
if arr[0].Type == gjson.String {
|
||||
// Array of strings
|
||||
var texts []string
|
||||
for _, item := range arr {
|
||||
if item.Type == gjson.String {
|
||||
texts = append(texts, item.String())
|
||||
}
|
||||
}
|
||||
return joinTexts(texts), false
|
||||
} else if arr[0].Type == gjson.Number {
|
||||
// Array of token IDs - not supported for text detection
|
||||
log.Info("embedding input is token ID array, not supported for text detection")
|
||||
return "", true
|
||||
}
|
||||
} else if json.Type == gjson.String {
|
||||
// Single string
|
||||
return json.String(), false
|
||||
}
|
||||
|
||||
// Unknown type
|
||||
log.Warnf("embedding input has unsupported type: %v", json.Type)
|
||||
return "", true
|
||||
}
|
||||
|
||||
// joinTexts concatenates texts, placing a single newline between consecutive
// elements. A nil or empty slice yields the empty string.
func joinTexts(texts []string) string {
	return strings.Join(texts, "\n")
}
|
||||
|
||||
// structuralFields lists the object keys whose values are treated as
// structure or metadata rather than user content. extractStringLeaves does not
// descend into values stored under these keys.
var structuralFields = map[string]bool{
	"id":        true, // Response ID
	"requestId": true, // Request ID
	"object":    true, // JSON structure identifier
	"model":     true, // Model name
	"index":     true, // Array index marker
	"encoding":  true, // Encoding format
}
|
||||
|
||||
// extractStringLeaves recursively extracts string values from a JSON structure
|
||||
// Skips structural/metadata fields that are not user content
|
||||
func extractStringLeaves(json gjson.Result, texts *[]string) {
|
||||
if json.Type == gjson.String {
|
||||
*texts = append(*texts, json.String())
|
||||
return
|
||||
}
|
||||
|
||||
if json.IsArray() {
|
||||
for _, item := range json.Array() {
|
||||
extractStringLeaves(item, texts)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if json.IsObject() {
|
||||
json.ForEach(func(key, value gjson.Result) bool {
|
||||
// Skip structural/metadata fields
|
||||
if structuralFields[key.String()] {
|
||||
return true
|
||||
}
|
||||
// Skip embedding vectors (numeric arrays or base64 strings)
|
||||
if key.String() == "embedding" {
|
||||
return true
|
||||
}
|
||||
extractStringLeaves(value, texts)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleEmbeddingRequestBody handles request body for Embedding API
|
||||
func HandleEmbeddingRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
|
||||
// Extract text from input field
|
||||
input := gjson.GetBytes(body, config.RequestContentJsonPath)
|
||||
content, unsupportedType := parseInput(input)
|
||||
|
||||
log.Debugf("Embedding request content: %s, unsupportedType: %v", content, unsupportedType)
|
||||
|
||||
// Handle unsupported input types (e.g., token ID arrays)
|
||||
if unsupportedType {
|
||||
log.Info("embedding request has unsupported input type, skipping text detection")
|
||||
ctx.SetUserAttribute("safecheck_status", "request skip - unsupported input type")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
log.Info("embedding request content is empty, skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
contentIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Risk detected - send Embedding-compatible error response
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
|
||||
// Use Embedding-specific error response format
|
||||
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
|
||||
jsonData := []byte(fmt.Sprintf(EmbeddingErrorResponseFormat, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
// HandleEmbeddingResponseHeaders handles response headers for Embedding API
|
||||
func HandleEmbeddingResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
ctx.BufferResponseBody()
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
// HandleEmbeddingResponseBody handles response body for Embedding API
|
||||
func HandleEmbeddingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking embedding response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
|
||||
// Priority 1: Check error.message for error responses
|
||||
var content string
|
||||
if config.ResponseErrorContentJsonPath != "" {
|
||||
content = gjson.GetBytes(body, config.ResponseErrorContentJsonPath).String()
|
||||
}
|
||||
|
||||
// Priority 2: Extract string leaves from data field
|
||||
if len(content) == 0 {
|
||||
data := gjson.GetBytes(body, config.ResponseContentJsonPath)
|
||||
var texts []string
|
||||
extractStringLeaves(data, &texts)
|
||||
if len(texts) > 0 {
|
||||
content = joinTexts(texts)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Embedding response content length: %d", len(content))
|
||||
|
||||
if len(content) == 0 {
|
||||
// No text found - this is normal for standard embedding responses that only contain vectors
|
||||
log.Info("embedding response has no text content (likely vector-only response), skipping text detection")
|
||||
ctx.SetUserAttribute("safecheck_status", "response skip - no text content")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
contentIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Risk detected - send Embedding-compatible error response
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
|
||||
// Use Embedding-specific error response format
|
||||
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
|
||||
jsonData := []byte(fmt.Sprintf(EmbeddingErrorResponseFormat, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
|
||||
config.IncrementCounter("ai_sec_response_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
checkService := config.GetResponseCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package multi_modal_guard
|
||||
import (
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/embedding"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
|
||||
@@ -31,6 +32,8 @@ func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, bod
|
||||
}
|
||||
case cfg.ApiMCP:
|
||||
return mcp.HandleMcpRequestBody(ctx, config, body)
|
||||
case cfg.ApiEmbedding:
|
||||
return embedding.HandleEmbeddingRequestBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("[on request body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
@@ -58,6 +61,8 @@ func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig)
|
||||
ctx.NeedPauseStreamingResponse()
|
||||
return types.ActionContinue
|
||||
}
|
||||
case cfg.ApiEmbedding:
|
||||
return embedding.HandleEmbeddingResponseHeaders(ctx, config)
|
||||
default:
|
||||
log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
@@ -70,6 +75,10 @@ func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityC
|
||||
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||
case cfg.ApiMCP:
|
||||
return mcp.HandleMcpStreamingResponseBody(ctx, config, data, endOfStream)
|
||||
case cfg.ApiEmbedding:
|
||||
// Embedding doesn't support streaming responses; pass through and log warning
|
||||
log.Warnf("[on streaming response body] embedding api doesn't support streaming, ignoring responseStreamContentJsonPath")
|
||||
return data
|
||||
default:
|
||||
log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return data
|
||||
@@ -92,6 +101,8 @@ func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, bo
|
||||
}
|
||||
case cfg.ApiMCP:
|
||||
return mcp.HandleMcpResponseBody(ctx, config, body)
|
||||
case cfg.ApiEmbedding:
|
||||
return embedding.HandleEmbeddingResponseBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
|
||||
@@ -61,7 +61,9 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig)
|
||||
return types.ActionContinue
|
||||
}
|
||||
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
|
||||
if statusCode != "200" {
|
||||
// For embedding API, we need to check error.message in non-200 responses
|
||||
// so we don't skip response body check for embedding apiType
|
||||
if statusCode != "200" && config.ApiType != cfg.ApiEmbedding {
|
||||
log.Debugf("response is not 200, skip response body check")
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
|
||||
@@ -335,6 +335,72 @@ func mustDecodeLegacyDenyContent(t *testing.T, content string) cfg.DenyResponseB
|
||||
return denyBody
|
||||
}
|
||||
|
||||
// Test configuration: Embedding API with both request and response checks
// enabled.
var embeddingConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":                  "security-service",
		"servicePort":                  8080,
		"serviceHost":                  "security.example.com",
		"accessKey":                    "test-ak",
		"secretKey":                    "test-sk",
		"checkRequest":                 true,
		"checkResponse":                true,
		"action":                       "MultiModalGuard",
		"apiType":                      "embedding",
		"requestContentJsonPath":       "input",
		"responseContentJsonPath":      "data",
		"responseErrorContentJsonPath": "error.message",
		"contentModerationLevelBar":    "high",
		"promptAttackLevelBar":         "high",
		"sensitiveDataLevelBar":        "S3",
		"timeout":                      2000,
	}
	// The map holds only strings, ints and bools, so the Marshal error is dropped.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
|
||||
|
||||
// Test configuration: Embedding API with request checking only.
var embeddingRequestOnlyConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":               "security-service",
		"servicePort":               8080,
		"serviceHost":               "security.example.com",
		"accessKey":                 "test-ak",
		"secretKey":                 "test-sk",
		"checkRequest":              true,
		"checkResponse":             false,
		"action":                    "MultiModalGuard",
		"apiType":                   "embedding",
		"requestContentJsonPath":    "input",
		"contentModerationLevelBar": "high",
		"promptAttackLevelBar":      "high",
		"sensitiveDataLevelBar":     "S3",
		"timeout":                   2000,
	}
	// The map holds only strings, ints and bools, so the Marshal error is dropped.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
|
||||
|
||||
// Test configuration: Embedding API with response checking only.
var embeddingResponseOnlyConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"serviceName":                  "security-service",
		"servicePort":                  8080,
		"serviceHost":                  "security.example.com",
		"accessKey":                    "test-ak",
		"secretKey":                    "test-sk",
		"checkRequest":                 false,
		"checkResponse":                true,
		"action":                       "MultiModalGuard",
		"apiType":                      "embedding",
		"responseContentJsonPath":      "data",
		"responseErrorContentJsonPath": "error.message",
		"contentModerationLevelBar":    "high",
		"promptAttackLevelBar":         "high",
		"sensitiveDataLevelBar":        "S3",
		"timeout":                      2000,
	}
	// The map holds only strings, ints and bools, so the Marshal error is dropped.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基础配置解析
|
||||
@@ -4211,3 +4277,402 @@ func TestTextModerationPlusRequestDenyGuardrailShape(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddingConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("embedding config with responseErrorContentJsonPath", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, "embedding", securityConfig.ApiType)
|
||||
require.Equal(t, "input", securityConfig.RequestContentJsonPath)
|
||||
require.Equal(t, "data", securityConfig.ResponseContentJsonPath)
|
||||
require.Equal(t, "error.message", securityConfig.ResponseErrorContentJsonPath)
|
||||
require.Equal(t, true, securityConfig.CheckRequest)
|
||||
require.Equal(t, true, securityConfig.CheckResponse)
|
||||
})
|
||||
|
||||
t.Run("embedding config without responseErrorContentJsonPath", func(t *testing.T) {
|
||||
// Test backward compatibility when responseErrorContentJsonPath is not provided
|
||||
configWithoutErrorPath := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "security-service",
|
||||
"servicePort": 8080,
|
||||
"serviceHost": "security.example.com",
|
||||
"accessKey": "test-ak",
|
||||
"secretKey": "test-sk",
|
||||
"checkRequest": true,
|
||||
"checkResponse": true,
|
||||
"action": "MultiModalGuard",
|
||||
"apiType": "embedding",
|
||||
"requestContentJsonPath": "input",
|
||||
"responseContentJsonPath": "data",
|
||||
"contentModerationLevelBar": "high",
|
||||
"timeout": 2000,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
host, status := test.NewTestHost(configWithoutErrorPath)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
securityConfig := config.(*cfg.AISecurityConfig)
|
||||
require.Equal(t, "embedding", securityConfig.ApiType)
|
||||
require.Equal(t, "", securityConfig.ResponseErrorContentJsonPath)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddingRequest(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("embedding request with string input pass", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
body := `{"input": "Hello, how are you?", "model": "text-embedding-ada-002"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-pass", "Data": {"RiskLevel": "low"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
action = host.GetHttpStreamAction()
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
t.Run("embedding request with string array input pass", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
body := `{"input": ["Hello", "World"], "model": "text-embedding-ada-002"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-array-pass", "Data": {"RiskLevel": "low"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
action = host.GetHttpStreamAction()
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
t.Run("embedding request with token ID array skip", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// Token ID array input - should skip detection
|
||||
body := `{"input": [1234, 5678, 9012], "model": "text-embedding-ada-002"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// Should continue without checking (unsupported input type)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
t.Run("embedding request deny with embedding error format", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingRequestOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
body := `{"input": "bad content", "model": "text-embedding-ada-002"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-deny", "Data": {"RiskLevel": "high"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
local := host.GetLocalResponse()
|
||||
require.NotNil(t, local, "expected SendHttpResponse for Embedding request deny")
|
||||
// Verify the response uses Embedding error format
|
||||
var errorResp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(local.Data, &errorResp))
|
||||
require.Contains(t, errorResp, "error")
|
||||
errorObj := errorResp["error"].(map[string]interface{})
|
||||
require.Contains(t, errorObj, "message")
|
||||
require.Contains(t, errorObj, "type")
|
||||
require.Contains(t, errorObj, "code")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddingResponse(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("embedding response with error message", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// Response with error.message field
|
||||
body := `{"error": {"message": "Rate limit exceeded", "type": "rate_limit_error"}, "data": []}`
|
||||
action := host.CallOnHttpResponseBody([]byte(body))
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-resp-error", "Data": {"RiskLevel": "low"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
action = host.GetHttpStreamAction()
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
t.Run("embedding response vector only skip", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// Standard embedding response with only vectors - no text content
|
||||
body := `{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
|
||||
{"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]}
|
||||
],
|
||||
"model": "text-embedding-ada-002",
|
||||
"usage": {"prompt_tokens": 10, "total_tokens": 10}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(body))
|
||||
// Should skip since no text content
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
t.Run("embedding response base64 vector skip", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// Embedding response with base64 encoding_format - embedding is a string, not an array
|
||||
body := `{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"object": "embedding", "index": 0, "embedding": "AGC3PAAAtzzAQLc8gEC3PEBAtzy"}
|
||||
],
|
||||
"model": "text-embedding-ada-002",
|
||||
"usage": {"prompt_tokens": 10, "total_tokens": 10}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(body))
|
||||
// Should skip since base64 embedding strings are not user content
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// Subtest: a 200 embedding response whose body carries text in error.message
// pauses the response-body phase; once the mocked security service reports a
// high risk level, the host must have recorded a local (deny) response whose
// JSON body contains an "error" object with message, type and code fields.
t.Run("embedding response deny with embedding error format", func(t *testing.T) {
	host, status := test.NewTestHost(embeddingResponseOnlyConfig)
	defer host.Reset()
	require.Equal(t, types.OnPluginStartStatusOK, status)

	host.CallOnHttpRequestHeaders([][2]string{
		{":authority", "example.com"},
		{":path", "/v1/embeddings"},
		{":method", "POST"},
	})

	host.CallOnHttpResponseHeaders([][2]string{
		{":status", "200"},
		{"content-type", "application/json"},
	})

	// Response with text content in error.message
	body := `{"error": {"message": "bad response content"}, "data": []}`
	action := host.CallOnHttpResponseBody([]byte(body))
	require.Equal(t, types.ActionPause, action)

	// Mocked reply from the security service reporting a high risk level.
	securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-resp-deny", "Data": {"RiskLevel": "high"}}`
	host.CallOnHttpCall([][2]string{
		{":status", "200"},
		{"content-type", "application/json"},
	}, []byte(securityResponse))

	local := host.GetLocalResponse()
	require.NotNil(t, local, "expected SendHttpResponse for Embedding response deny")

	// Verify the response uses Embedding error format
	var errorResp map[string]interface{}
	require.NoError(t, json.Unmarshal(local.Data, &errorResp))
	require.Contains(t, errorResp, "error")

	// Use the checked form of the type assertion so that a malformed deny body
	// fails the test with a clear message instead of panicking.
	errorObj, ok := errorResp["error"].(map[string]interface{})
	require.True(t, ok, "expected \"error\" to be a JSON object")
	require.Contains(t, errorObj, "message")
	require.Contains(t, errorObj, "type")
	require.Contains(t, errorObj, "code")
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddingStreamingIgnored(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("embedding streaming response ignores responseStreamContentJsonPath", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(embeddingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// Simulate streaming response headers
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// Even if streaming content path is set, embedding should process non-streaming
|
||||
body := `{
|
||||
"object": "list",
|
||||
"data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2]}],
|
||||
"model": "text-embedding-ada-002"
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(body))
|
||||
// Should continue since there's no text content
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddingNon200Response(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("embedding API should check response body for non-200 status", func(t *testing.T) {
|
||||
// Embedding API with responseErrorContentJsonPath should check error.message
|
||||
// even when status code is not 200
|
||||
host, status := test.NewTestHost(embeddingResponseOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// Non-200 response (e.g., 400 Bad Request)
|
||||
// For embedding API, response body should be buffered for later processing
|
||||
// HandleEmbeddingResponseHeaders returns HeaderStopIteration (ActionPause)
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "400"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
// HeaderStopIteration = ActionPause indicates body will be buffered and processed
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Response body with error.message should be checked
|
||||
body := `{"error": {"message": "Invalid input: sensitive content detected", "type": "invalid_request_error"}}`
|
||||
action = host.CallOnHttpResponseBody([]byte(body))
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// Simulate security service response with high risk
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-non200", "Data": {"RiskLevel": "high"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
// Verify deny response was sent (Embedding error format)
|
||||
local := host.GetLocalResponse()
|
||||
require.NotNil(t, local, "expected SendHttpResponse for Embedding deny")
|
||||
var errorResp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(local.Data, &errorResp))
|
||||
require.Contains(t, errorResp, "error")
|
||||
errorObj := errorResp["error"].(map[string]interface{})
|
||||
require.Contains(t, errorObj, "message")
|
||||
})
|
||||
|
||||
t.Run("non-embedding API should skip response body for non-200 status", func(t *testing.T) {
|
||||
// Non-embedding API should maintain existing behavior: skip response body
|
||||
// for non-200 responses
|
||||
host, status := test.NewTestHost(multiModalGuardTextConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// Non-200 response
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "500"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
// For non-embedding API, should skip response body check
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user