From 58ffdae6bae31fa302a4220e8fbf099bbbb97247 Mon Sep 17 00:00:00 2001 From: JianweiWang Date: Tue, 2 Jun 2026 12:11:00 +0800 Subject: [PATCH] feat(ai-security-guard): add Embedding API content detection support (#3895) Signed-off-by: root --- .../ai-security-guard/config/config.go | 7 +- .../multi_modal_guard/embedding/openai.go | 333 +++++++++++++ .../lvwang/multi_modal_guard/handler.go | 11 + .../extensions/ai-security-guard/main.go | 4 +- .../extensions/ai-security-guard/main_test.go | 465 ++++++++++++++++++ 5 files changed, 818 insertions(+), 2 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/embedding/openai.go diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/config.go b/plugins/wasm-go/extensions/ai-security-guard/config/config.go index 9f5ac09e1..ceac4395e 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/config/config.go +++ b/plugins/wasm-go/extensions/ai-security-guard/config/config.go @@ -118,6 +118,7 @@ const ( ApiTextGeneration = "text_generation" ApiImageGeneration = "image_generation" ApiMCP = "mcp" + ApiEmbedding = "embedding" ) // provider types @@ -206,6 +207,7 @@ type AISecurityConfig struct { ResponseStreamContentJsonPath string ResponseContentFallbackJsonPaths []string ResponseStreamContentFallbackJsonPaths []string + ResponseErrorContentJsonPath string DenyCode int64 DenyMessage string ProtocolOriginal bool @@ -223,7 +225,7 @@ type AISecurityConfig struct { ConsumerRequestCheckService []map[string]interface{} ConsumerResponseCheckService []map[string]interface{} ConsumerRiskLevel []map[string]interface{} - // text_generation, image_generation, etc. + // text_generation, image_generation, embedding, etc. ApiType string // openai, qwen, comfyui, etc. ProviderType string @@ -355,6 +357,9 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error { } else if exists { config.ResponseStreamContentFallbackJsonPaths = paths } + if obj := json.Get("responseErrorContentJsonPath"); obj.Exists() { + config.ResponseErrorContentJsonPath = obj.String() + } if obj := json.Get("contentModerationLevelBar"); obj.Exists() { config.ContentModerationLevelBar = obj.String() if LevelToInt(config.ContentModerationLevelBar) <= 0 { diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/embedding/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/embedding/openai.go new file mode 100644 index 000000000..d1706236c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/embedding/openai.go @@ -0,0 +1,333 @@ +package embedding + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +// OpenAI Embedding error response format +const EmbeddingErrorResponseFormat = `{"error": {"message": "%s", "type": "invalid_request_error", "param": null, "code": "content_policy_violation"}}` + +// parseInput extracts text from the input field of an Embedding request. +// input can be: +// - A string: returns the string directly +// - An array of strings: returns all strings joined +// - An array of integers (token IDs): returns empty with unsupportedType=true +func parseInput(json gjson.Result) (text string, unsupportedType bool) { + if json.IsArray() { + // Check if it's an array of strings or token IDs + arr := json.Array() + if len(arr) == 0 { + return "", false + } + + // Check first element type + if arr[0].Type == gjson.String { + // Array of strings + var texts []string + for _, item := range arr { + if item.Type == gjson.String { + texts = append(texts, item.String()) + } + } + return joinTexts(texts), false + } else if arr[0].Type == gjson.Number { + // Array of token IDs - not supported for text detection + log.Info("embedding input is token ID array, not supported for text detection") + return "", true + } + } else if json.Type == gjson.String { + // Single string + return json.String(), false + } + + // Unknown type + log.Warnf("embedding input has unsupported type: %v", json.Type) + return "", true +} + +// joinTexts joins multiple text strings with newline separator +func joinTexts(texts []string) string { + result := "" + for i, t := range texts { + if i > 0 { + result += "\n" + } + result += t + } + return result +} + +// structuralFields contains field names that should be skipped when extracting text content +// These are structural/metadata fields, not user content +var structuralFields = map[string]bool{ + "object": true, // JSON structure identifier + "model": true, // Model name + "index": true, // Array index marker + "encoding": true, // Encoding format + "id": true, // Response ID + "requestId": true, // Request ID +} + +// extractStringLeaves recursively extracts string values from a JSON structure +// Skips structural/metadata fields that are not user content +func extractStringLeaves(json gjson.Result, texts *[]string) { + if json.Type == gjson.String { + *texts = append(*texts, json.String()) + return + } + + if json.IsArray() { + for _, item := range json.Array() { + extractStringLeaves(item, texts) + } + return + } + + if json.IsObject() { + json.ForEach(func(key, value gjson.Result) bool { + // Skip structural/metadata fields + if structuralFields[key.String()] { + return true + } + // Skip embedding vectors (numeric arrays or base64 strings) + if key.String() == "embedding" { + return true + } + extractStringLeaves(value, texts) + return true + }) + } +} + +// HandleEmbeddingRequestBody handles request body for Embedding API +func HandleEmbeddingRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + checkService := config.GetRequestCheckService(consumer) + startTime := time.Now().UnixMilli() + + // Extract text from input field + input := gjson.GetBytes(body, config.RequestContentJsonPath) + content, unsupportedType := parseInput(input) + + log.Debugf("Embedding request content: %s, unsupportedType: %v", content, unsupportedType) + + // Handle unsupported input types (e.g., token ID arrays) + if unsupportedType { + log.Info("embedding request has unsupported input type, skipping text detection") + ctx.SetUserAttribute("safecheck_status", "request skip - unsupported input type") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + return types.ActionContinue + } + + if len(content) == 0 { + log.Info("embedding request content is empty, skip") + return types.ActionContinue + } + + contentIndex := 0 + sessionID, _ := utils.GenerateHexID(20) + var singleCall func() + + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpRequest() + return + } + + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Errorf("%+v", err) + proxywasm.ResumeHttpRequest() + return + } + + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpRequest() + } else { + singleCall() + } + return + } + + // Risk detected - send Embedding-compatible error response + denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) + if err != nil { + log.Errorf("failed to build deny response body: %v", err) + proxywasm.ResumeHttpRequest() + return + } + + // Use Embedding-specific error response format + marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) + jsonData := []byte(fmt.Sprintf(EmbeddingErrorResponseFormat, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + + ctx.DontReadResponseBody() + config.IncrementCounter("ai_sec_request_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + + singleCall = func() { + var nextContentIndex int + if contentIndex+cfg.LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + cfg.LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + log.Debugf("current content piece: %s", contentPiece) + path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpRequest() + } + } + + singleCall() + return types.ActionPause +} + +// HandleEmbeddingResponseHeaders handles response headers for Embedding API +func HandleEmbeddingResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + ctx.BufferResponseBody() + return types.HeaderStopIteration +} + +// HandleEmbeddingResponseBody handles response body for Embedding API +func HandleEmbeddingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + log.Debugf("checking embedding response body...") + startTime := time.Now().UnixMilli() + + // Priority 1: Check error.message for error responses + var content string + if config.ResponseErrorContentJsonPath != "" { + content = gjson.GetBytes(body, config.ResponseErrorContentJsonPath).String() + } + + // Priority 2: Extract string leaves from data field + if len(content) == 0 { + data := gjson.GetBytes(body, config.ResponseContentJsonPath) + var texts []string + extractStringLeaves(data, &texts) + if len(texts) > 0 { + content = joinTexts(texts) + } + } + + log.Debugf("Embedding response content length: %d", len(content)) + + if len(content) == 0 { + // No text found - this is normal for standard embedding responses that only contain vectors + log.Info("embedding response has no text content (likely vector-only response), skipping text detection") + ctx.SetUserAttribute("safecheck_status", "response skip - no text content") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + return types.ActionContinue + } + + contentIndex := 0 + sessionID, _ := utils.GenerateHexID(20) + var singleCall func() + + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpResponse() + return + } + + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at response phase") + proxywasm.ResumeHttpResponse() + return + } + + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "response pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpResponse() + } else { + singleCall() + } + return + } + + // Risk detected - send Embedding-compatible error response + denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer) + if err != nil { + log.Errorf("failed to build deny response body: %v", err) + proxywasm.ResumeHttpResponse() + return + } + + // Use Embedding-specific error response format + marshalledDenyMessage := wrapper.MarshalStr(string(denyBody)) + jsonData := []byte(fmt.Sprintf(EmbeddingErrorResponseFormat, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + + config.IncrementCounter("ai_sec_response_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "response deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + + singleCall = func() { + var nextContentIndex int + if contentIndex+cfg.LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + cfg.LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + log.Debugf("current content piece: %s", contentPiece) + checkService := config.GetResponseCheckService(consumer) + path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpResponse() + } + } + + singleCall() + return types.ActionPause +} \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go index 98b06f2ee..07f2fcbc7 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go @@ -3,6 +3,7 @@ package multi_modal_guard import ( cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/embedding" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/mcp" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text" @@ -31,6 +32,8 @@ func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, bod } case cfg.ApiMCP: return mcp.HandleMcpRequestBody(ctx, config, body) + case cfg.ApiEmbedding: + return embedding.HandleEmbeddingRequestBody(ctx, config, body) default: log.Errorf("[on request body] multi_modal_guard don't support api: %s", config.ApiType) return types.ActionContinue @@ -58,6 +61,8 @@ func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) ctx.NeedPauseStreamingResponse() return types.ActionContinue } + case cfg.ApiEmbedding: + return embedding.HandleEmbeddingResponseHeaders(ctx, config) default: log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType) return types.ActionContinue @@ -70,6 +75,10 @@ func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityC return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream) case cfg.ApiMCP: return mcp.HandleMcpStreamingResponseBody(ctx, config, data, endOfStream) + case cfg.ApiEmbedding: + // Embedding doesn't support streaming responses; pass through and log warning + log.Warnf("[on streaming response body] embedding api doesn't support streaming, ignoring responseStreamContentJsonPath") + return data default: log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType) return data @@ -92,6 +101,8 @@ func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, bo } case cfg.ApiMCP: return mcp.HandleMcpResponseBody(ctx, config, body) + case cfg.ApiEmbedding: + return embedding.HandleEmbeddingResponseBody(ctx, config, body) default: log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType) return types.ActionContinue diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index eb23af06c..57a5ae32f 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -61,7 +61,9 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) return types.ActionContinue } statusCode, _ := proxywasm.GetHttpResponseHeader(":status") - if statusCode != "200" { + // For embedding API, we need to check error.message in non-200 responses + // so we don't skip response body check for embedding apiType + if statusCode != "200" && config.ApiType != cfg.ApiEmbedding { log.Debugf("response is not 200, skip response body check") ctx.DontReadResponseBody() return types.ActionContinue diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 9afe34485..f3d8415e9 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -335,6 +335,72 @@ func mustDecodeLegacyDenyContent(t *testing.T, content string) cfg.DenyResponseB return denyBody } +// 测试配置:Embedding API +var embeddingConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "apiType": "embedding", + "requestContentJsonPath": "input", + "responseContentJsonPath": "data", + "responseErrorContentJsonPath": "error.message", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + +// 测试配置:Embedding API 仅请求检测 +var embeddingRequestOnlyConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": false, + "action": "MultiModalGuard", + "apiType": "embedding", + "requestContentJsonPath": "input", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + +// 测试配置:Embedding API 仅响应检测 +var embeddingResponseOnlyConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": false, + "checkResponse": true, + "action": "MultiModalGuard", + "apiType": "embedding", + "responseContentJsonPath": "data", + "responseErrorContentJsonPath": "error.message", + "contentModerationLevelBar": "high", + "promptAttackLevelBar": "high", + "sensitiveDataLevelBar": "S3", + "timeout": 2000, + }) + return data +}() + func TestParseConfig(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试基础配置解析 @@ -4211,3 +4277,402 @@ func TestTextModerationPlusRequestDenyGuardrailShape(t *testing.T) { }) }) } + +func TestEmbeddingConfig(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("embedding config with responseErrorContentJsonPath", func(t *testing.T) { + host, status := test.NewTestHost(embeddingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "embedding", securityConfig.ApiType) + require.Equal(t, "input", securityConfig.RequestContentJsonPath) + require.Equal(t, "data", securityConfig.ResponseContentJsonPath) + require.Equal(t, "error.message", securityConfig.ResponseErrorContentJsonPath) + require.Equal(t, true, securityConfig.CheckRequest) + require.Equal(t, true, securityConfig.CheckResponse) + }) + + t.Run("embedding config without responseErrorContentJsonPath", func(t *testing.T) { + // Test backward compatibility when responseErrorContentJsonPath is not provided + configWithoutErrorPath := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "serviceName": "security-service", + "servicePort": 8080, + "serviceHost": "security.example.com", + "accessKey": "test-ak", + "secretKey": "test-sk", + "checkRequest": true, + "checkResponse": true, + "action": "MultiModalGuard", + "apiType": "embedding", + "requestContentJsonPath": "input", + "responseContentJsonPath": "data", + "contentModerationLevelBar": "high", + "timeout": 2000, + }) + return data + }() + host, status := test.NewTestHost(configWithoutErrorPath) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "embedding", securityConfig.ApiType) + require.Equal(t, "", securityConfig.ResponseErrorContentJsonPath) + }) + }) +} + +func TestEmbeddingRequest(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("embedding request with string input pass", func(t *testing.T) { + host, status := test.NewTestHost(embeddingRequestOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + body := `{"input": "Hello, how are you?", "model": "text-embedding-ada-002"}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-pass", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + t.Run("embedding request with string array input pass", func(t *testing.T) { + host, status := test.NewTestHost(embeddingRequestOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + body := `{"input": ["Hello", "World"], "model": "text-embedding-ada-002"}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-array-pass", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + t.Run("embedding request with token ID array skip", func(t *testing.T) { + host, status := test.NewTestHost(embeddingRequestOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + // Token ID array input - should skip detection + body := `{"input": [1234, 5678, 9012], "model": "text-embedding-ada-002"}` + action := host.CallOnHttpRequestBody([]byte(body)) + // Should continue without checking (unsupported input type) + require.Equal(t, types.ActionContinue, action) + }) + + t.Run("embedding request deny with embedding error format", func(t *testing.T) { + host, status := test.NewTestHost(embeddingRequestOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + body := `{"input": "bad content", "model": "text-embedding-ada-002"}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for Embedding request deny") + // Verify the response uses Embedding error format + var errorResp map[string]interface{} + require.NoError(t, json.Unmarshal(local.Data, &errorResp)) + require.Contains(t, errorResp, "error") + errorObj := errorResp["error"].(map[string]interface{}) + require.Contains(t, errorObj, "message") + require.Contains(t, errorObj, "type") + require.Contains(t, errorObj, "code") + }) + }) +} + +func TestEmbeddingResponse(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("embedding response with error message", func(t *testing.T) { + host, status := test.NewTestHost(embeddingResponseOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + // Response with error.message field + body := `{"error": {"message": "Rate limit exceeded", "type": "rate_limit_error"}, "data": []}` + action := host.CallOnHttpResponseBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-resp-error", "Data": {"RiskLevel": "low"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + action = host.GetHttpStreamAction() + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + t.Run("embedding response vector only skip", func(t *testing.T) { + host, status := test.NewTestHost(embeddingResponseOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + // Standard embedding response with only vectors - no text content + body := `{ + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}, + {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 10, "total_tokens": 10} + }` + action := host.CallOnHttpResponseBody([]byte(body)) + // Should skip since no text content + require.Equal(t, types.ActionContinue, action) + }) + + t.Run("embedding response base64 vector skip", func(t *testing.T) { + host, status := test.NewTestHost(embeddingResponseOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + // Embedding response with base64 encoding_format - embedding is a string, not an array + body := `{ + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": "AGC3PAAAtzzAQLc8gEC3PEBAtzy"} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 10, "total_tokens": 10} + }` + action := host.CallOnHttpResponseBody([]byte(body)) + // Should skip since base64 embedding strings are not user content + require.Equal(t, types.ActionContinue, action) + }) + + t.Run("embedding response deny with embedding error format", func(t *testing.T) { + host, status := test.NewTestHost(embeddingResponseOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + // Response with text content in error.message + body := `{"error": {"message": "bad response content"}, "data": []}` + action := host.CallOnHttpResponseBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-resp-deny", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for Embedding response deny") + // Verify the response uses Embedding error format + var errorResp map[string]interface{} + require.NoError(t, json.Unmarshal(local.Data, &errorResp)) + require.Contains(t, errorResp, "error") + errorObj := errorResp["error"].(map[string]interface{}) + require.Contains(t, errorObj, "message") + require.Contains(t, errorObj, "type") + require.Contains(t, errorObj, "code") + }) + }) +} + +func TestEmbeddingStreamingIgnored(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("embedding streaming response ignores responseStreamContentJsonPath", func(t *testing.T) { + host, status := test.NewTestHost(embeddingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + // Simulate streaming response headers + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + // Even if streaming content path is set, embedding should process non-streaming + body := `{ + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2]}], + "model": "text-embedding-ada-002" + }` + action := host.CallOnHttpResponseBody([]byte(body)) + // Should continue since there's no text content + require.Equal(t, types.ActionContinue, action) + }) + }) +} + +func TestEmbeddingNon200Response(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("embedding API should check response body for non-200 status", func(t *testing.T) { + // Embedding API with responseErrorContentJsonPath should check error.message + // even when status code is not 200 + host, status := test.NewTestHost(embeddingResponseOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + }) + + // Non-200 response (e.g., 400 Bad Request) + // For embedding API, response body should be buffered for later processing + // HandleEmbeddingResponseHeaders returns HeaderStopIteration (ActionPause) + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "400"}, + {"content-type", "application/json"}, + }) + // HeaderStopIteration = ActionPause indicates body will be buffered and processed + require.Equal(t, types.HeaderStopIteration, action) + + // Response body with error.message should be checked + body := `{"error": {"message": "Invalid input: sensitive content detected", "type": "invalid_request_error"}}` + action = host.CallOnHttpResponseBody([]byte(body)) + require.Equal(t, types.ActionPause, action) + + // Simulate security service response with high risk + securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-embed-non200", "Data": {"RiskLevel": "high"}}` + host.CallOnHttpCall([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }, []byte(securityResponse)) + + // Verify deny response was sent (Embedding error format) + local := host.GetLocalResponse() + require.NotNil(t, local, "expected SendHttpResponse for Embedding deny") + var errorResp map[string]interface{} + require.NoError(t, json.Unmarshal(local.Data, &errorResp)) + require.Contains(t, errorResp, "error") + errorObj := errorResp["error"].(map[string]interface{}) + require.Contains(t, errorObj, "message") + }) + + t.Run("non-embedding API should skip response body for non-200 status", func(t *testing.T) { + // Non-embedding API should maintain existing behavior: skip response body + // for non-200 responses + host, status := test.NewTestHost(multiModalGuardTextConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + }) + + // Non-200 response + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "500"}, + {"content-type", "application/json"}, + }) + // For non-embedding API, should skip response body check + require.Equal(t, types.ActionContinue, action) + }) + }) +}