Ai data masking msg window (#1775)

This commit is contained in:
007gzs
2025-02-26 20:48:37 +08:00
committed by GitHub
parent 9ea2410388
commit 2d8a8f26da
2 changed files with 402 additions and 69 deletions

View File

@@ -13,8 +13,10 @@
// limitations under the License.
mod deny_word;
mod msg_window;
use crate::deny_word::DenyWord;
use crate::msg_window::MsgWindow;
use fancy_regex::Regex;
use grok::patterns;
use higress_wasm_rust::log::Log;
@@ -27,8 +29,8 @@ use proxy_wasm::traits::{Context, HttpContext, RootContext};
use proxy_wasm::types::{Bytes, ContextType, DataAction, HeaderAction, LogLevel};
use rust_embed::Embed;
use serde::de::Error;
use serde::Deserialize;
use serde::Deserializer;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, VecDeque};
@@ -66,9 +68,12 @@ struct AiDataMasking {
config: Option<Rc<AiDataMaskingConfig>>,
mask_map: HashMap<String, Option<String>>,
is_openai: bool,
is_openai_stream: Option<bool>,
stream: bool,
res_body: Bytes,
log: Log,
msg_window: MsgWindow,
char_window_size: usize,
byte_window_size: usize,
}
fn deserialize_regexp<'de, D>(deserializer: D) -> Result<Regex, D::Error>
where
@@ -213,10 +218,33 @@ struct ResMessage {
#[serde(default)]
delta: Option<Message>,
}
#[derive(Default, Debug, Deserialize, Serialize, Clone)]
struct Usage {
completion_tokens: i32,
prompt_tokens: i32,
total_tokens: i32,
}
impl Usage {
pub fn add(&mut self, usage: &Usage) {
self.completion_tokens += usage.completion_tokens;
self.prompt_tokens += usage.prompt_tokens;
self.total_tokens += usage.total_tokens;
}
pub fn reset(&mut self) {
self.completion_tokens = 0;
self.prompt_tokens = 0;
self.total_tokens = 0;
}
}
#[derive(Default, Debug, Deserialize)]
struct Res {
#[serde(default)]
choices: Vec<ResMessage>,
#[serde(default)]
usage: Usage,
}
static SYSTEM_PATTERNS: &[(&str, &str)] = &[
@@ -334,9 +362,12 @@ impl RootContextWrapper<AiDataMaskingConfig> for AiDataMaskingRoot {
mask_map: HashMap::new(),
config: None,
is_openai: false,
is_openai_stream: None,
stream: false,
res_body: Bytes::new(),
msg_window: MsgWindow::new(),
log: Log::new(PLUGIN_NAME.to_string()),
char_window_size: 0,
byte_window_size: 0,
}))
}
}
@@ -416,32 +447,6 @@ impl AiDataMasking {
DataAction::StopIterationAndBuffer
}
fn process_sse_message(&mut self, sse_message: &str) -> Vec<String> {
let mut messages = Vec::new();
for msg in sse_message.split('\n') {
if !msg.starts_with("data:") {
continue;
}
let res: Res = if let Some(m) = msg.strip_prefix("data:") {
match serde_json::from_str(m) {
Ok(r) => r,
Err(_) => continue,
}
} else {
continue;
};
if res.choices.is_empty() {
continue;
}
for choice in &res.choices {
if let Some(delta) = &choice.delta {
messages.push(delta.content.clone());
}
}
}
messages
}
fn replace_request_msg(&mut self, message: &str) -> String {
let config = self.config.as_ref().unwrap();
let mut msg = message.to_string();
@@ -464,6 +469,13 @@ impl AiDataMasking {
}
Type::Replace => rule.regex.replace(from_word, &rule.value).to_string(),
};
if to_word.len() > self.byte_window_size {
self.byte_window_size = to_word.len();
}
if to_word.chars().count() > self.char_window_size {
self.char_window_size = to_word.chars().count();
}
replace_pair.push((from_word.to_string(), to_word.clone()));
if rule.restore && !to_word.is_empty() {
@@ -499,6 +511,7 @@ impl HttpContext for AiDataMasking {
_end_of_stream: bool,
) -> HeaderAction {
if has_request_body() {
self.set_http_request_header("Content-Length", None);
HeaderAction::StopIteration
} else {
HeaderAction::Continue
@@ -512,58 +525,41 @@ impl HttpContext for AiDataMasking {
self.set_http_response_header("Content-Length", None);
HeaderAction::Continue
}
fn on_http_response_body(&mut self, body_size: usize, _end_of_stream: bool) -> DataAction {
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction {
if !self.stream {
return DataAction::Continue;
}
if let Some(body) = self.get_http_response_body(0, body_size) {
self.res_body.extend(&body);
if let Ok(body_str) = String::from_utf8(self.res_body.clone()) {
if self.is_openai {
let messages = self.process_sse_message(&body_str);
if self.check_message(&messages.join("")) {
if body_size > 0 {
if let Some(body) = self.get_http_response_body(0, body_size) {
if self.is_openai && self.is_openai_stream.is_none() {
self.is_openai_stream = Some(body.starts_with(b"data:"));
}
self.msg_window.push(&body, self.is_openai_stream.unwrap());
if let Ok(mut msg) = String::from_utf8(self.msg_window.message.clone()) {
if self.check_message(&msg) {
return self.deny(true);
}
} else if self.check_message(&body_str) {
return self.deny(true);
}
}
if self.mask_map.is_empty() {
return DataAction::Continue;
}
if let Ok(body_str) = std::str::from_utf8(&body) {
let mut new_str = body_str.to_string();
if self.is_openai {
let messages = self.process_sse_message(body_str);
for message in messages {
let mut new_message = message.clone();
if !self.mask_map.is_empty() {
for (from_word, to_word) in self.mask_map.iter() {
if let Some(to) = to_word {
new_message = new_message.replace(from_word, to);
msg = msg.replace(from_word, to);
}
}
if new_message != message {
new_str = new_str.replace(
&json!(message).to_string(),
&json!(new_message).to_string(),
);
}
}
} else {
for (from_word, to_word) in self.mask_map.iter() {
if let Some(to) = to_word {
new_str = new_str.replace(from_word, to);
}
}
}
if new_str != body_str {
self.replace_http_response_body(new_str.as_bytes());
self.msg_window.message = msg.as_bytes().to_vec();
}
}
}
let new_body = if end_of_stream {
self.msg_window.finish(self.is_openai_stream.unwrap())
} else {
self.msg_window.pop(
self.char_window_size * 2,
self.byte_window_size * 2,
self.is_openai_stream.unwrap(),
)
};
self.replace_http_response_body(&new_body);
DataAction::Continue
}
}
@@ -586,7 +582,6 @@ impl HttpContextWrapper<AiDataMaskingConfig> for AiDataMasking {
return DataAction::Continue;
}
let config = self.config.as_ref().unwrap();
let mut req_body = match String::from_utf8(req_body.clone()) {
Ok(r) => r,
Err(_) => return DataAction::Continue,