mirror of
https://github.com/alibaba/higress.git
synced 2026-05-27 06:07:27 +08:00
Ai data masking msg window (#1775)
This commit is contained in:
@@ -13,8 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
mod deny_word;
|
||||
mod msg_window;
|
||||
|
||||
use crate::deny_word::DenyWord;
|
||||
use crate::msg_window::MsgWindow;
|
||||
use fancy_regex::Regex;
|
||||
use grok::patterns;
|
||||
use higress_wasm_rust::log::Log;
|
||||
@@ -27,8 +29,8 @@ use proxy_wasm::traits::{Context, HttpContext, RootContext};
|
||||
use proxy_wasm::types::{Bytes, ContextType, DataAction, HeaderAction, LogLevel};
|
||||
use rust_embed::Embed;
|
||||
use serde::de::Error;
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::{BTreeMap, HashMap, VecDeque};
|
||||
@@ -66,9 +68,12 @@ struct AiDataMasking {
|
||||
config: Option<Rc<AiDataMaskingConfig>>,
|
||||
mask_map: HashMap<String, Option<String>>,
|
||||
is_openai: bool,
|
||||
is_openai_stream: Option<bool>,
|
||||
stream: bool,
|
||||
res_body: Bytes,
|
||||
log: Log,
|
||||
msg_window: MsgWindow,
|
||||
char_window_size: usize,
|
||||
byte_window_size: usize,
|
||||
}
|
||||
fn deserialize_regexp<'de, D>(deserializer: D) -> Result<Regex, D::Error>
|
||||
where
|
||||
@@ -213,10 +218,33 @@ struct ResMessage {
|
||||
#[serde(default)]
|
||||
delta: Option<Message>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize, Serialize, Clone)]
|
||||
struct Usage {
|
||||
completion_tokens: i32,
|
||||
prompt_tokens: i32,
|
||||
total_tokens: i32,
|
||||
}
|
||||
|
||||
impl Usage {
|
||||
pub fn add(&mut self, usage: &Usage) {
|
||||
self.completion_tokens += usage.completion_tokens;
|
||||
self.prompt_tokens += usage.prompt_tokens;
|
||||
self.total_tokens += usage.total_tokens;
|
||||
}
|
||||
pub fn reset(&mut self) {
|
||||
self.completion_tokens = 0;
|
||||
self.prompt_tokens = 0;
|
||||
self.total_tokens = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize)]
|
||||
struct Res {
|
||||
#[serde(default)]
|
||||
choices: Vec<ResMessage>,
|
||||
#[serde(default)]
|
||||
usage: Usage,
|
||||
}
|
||||
|
||||
static SYSTEM_PATTERNS: &[(&str, &str)] = &[
|
||||
@@ -334,9 +362,12 @@ impl RootContextWrapper<AiDataMaskingConfig> for AiDataMaskingRoot {
|
||||
mask_map: HashMap::new(),
|
||||
config: None,
|
||||
is_openai: false,
|
||||
is_openai_stream: None,
|
||||
stream: false,
|
||||
res_body: Bytes::new(),
|
||||
msg_window: MsgWindow::new(),
|
||||
log: Log::new(PLUGIN_NAME.to_string()),
|
||||
char_window_size: 0,
|
||||
byte_window_size: 0,
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -416,32 +447,6 @@ impl AiDataMasking {
|
||||
DataAction::StopIterationAndBuffer
|
||||
}
|
||||
|
||||
fn process_sse_message(&mut self, sse_message: &str) -> Vec<String> {
|
||||
let mut messages = Vec::new();
|
||||
for msg in sse_message.split('\n') {
|
||||
if !msg.starts_with("data:") {
|
||||
continue;
|
||||
}
|
||||
let res: Res = if let Some(m) = msg.strip_prefix("data:") {
|
||||
match serde_json::from_str(m) {
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if res.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for choice in &res.choices {
|
||||
if let Some(delta) = &choice.delta {
|
||||
messages.push(delta.content.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
messages
|
||||
}
|
||||
fn replace_request_msg(&mut self, message: &str) -> String {
|
||||
let config = self.config.as_ref().unwrap();
|
||||
let mut msg = message.to_string();
|
||||
@@ -464,6 +469,13 @@ impl AiDataMasking {
|
||||
}
|
||||
Type::Replace => rule.regex.replace(from_word, &rule.value).to_string(),
|
||||
};
|
||||
if to_word.len() > self.byte_window_size {
|
||||
self.byte_window_size = to_word.len();
|
||||
}
|
||||
if to_word.chars().count() > self.char_window_size {
|
||||
self.char_window_size = to_word.chars().count();
|
||||
}
|
||||
|
||||
replace_pair.push((from_word.to_string(), to_word.clone()));
|
||||
|
||||
if rule.restore && !to_word.is_empty() {
|
||||
@@ -499,6 +511,7 @@ impl HttpContext for AiDataMasking {
|
||||
_end_of_stream: bool,
|
||||
) -> HeaderAction {
|
||||
if has_request_body() {
|
||||
self.set_http_request_header("Content-Length", None);
|
||||
HeaderAction::StopIteration
|
||||
} else {
|
||||
HeaderAction::Continue
|
||||
@@ -512,58 +525,41 @@ impl HttpContext for AiDataMasking {
|
||||
self.set_http_response_header("Content-Length", None);
|
||||
HeaderAction::Continue
|
||||
}
|
||||
fn on_http_response_body(&mut self, body_size: usize, _end_of_stream: bool) -> DataAction {
|
||||
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction {
|
||||
if !self.stream {
|
||||
return DataAction::Continue;
|
||||
}
|
||||
if let Some(body) = self.get_http_response_body(0, body_size) {
|
||||
self.res_body.extend(&body);
|
||||
|
||||
if let Ok(body_str) = String::from_utf8(self.res_body.clone()) {
|
||||
if self.is_openai {
|
||||
let messages = self.process_sse_message(&body_str);
|
||||
|
||||
if self.check_message(&messages.join("")) {
|
||||
if body_size > 0 {
|
||||
if let Some(body) = self.get_http_response_body(0, body_size) {
|
||||
if self.is_openai && self.is_openai_stream.is_none() {
|
||||
self.is_openai_stream = Some(body.starts_with(b"data:"));
|
||||
}
|
||||
self.msg_window.push(&body, self.is_openai_stream.unwrap());
|
||||
if let Ok(mut msg) = String::from_utf8(self.msg_window.message.clone()) {
|
||||
if self.check_message(&msg) {
|
||||
return self.deny(true);
|
||||
}
|
||||
} else if self.check_message(&body_str) {
|
||||
return self.deny(true);
|
||||
}
|
||||
}
|
||||
if self.mask_map.is_empty() {
|
||||
return DataAction::Continue;
|
||||
}
|
||||
if let Ok(body_str) = std::str::from_utf8(&body) {
|
||||
let mut new_str = body_str.to_string();
|
||||
if self.is_openai {
|
||||
let messages = self.process_sse_message(body_str);
|
||||
|
||||
for message in messages {
|
||||
let mut new_message = message.clone();
|
||||
if !self.mask_map.is_empty() {
|
||||
for (from_word, to_word) in self.mask_map.iter() {
|
||||
if let Some(to) = to_word {
|
||||
new_message = new_message.replace(from_word, to);
|
||||
msg = msg.replace(from_word, to);
|
||||
}
|
||||
}
|
||||
if new_message != message {
|
||||
new_str = new_str.replace(
|
||||
&json!(message).to_string(),
|
||||
&json!(new_message).to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (from_word, to_word) in self.mask_map.iter() {
|
||||
if let Some(to) = to_word {
|
||||
new_str = new_str.replace(from_word, to);
|
||||
}
|
||||
}
|
||||
}
|
||||
if new_str != body_str {
|
||||
self.replace_http_response_body(new_str.as_bytes());
|
||||
self.msg_window.message = msg.as_bytes().to_vec();
|
||||
}
|
||||
}
|
||||
}
|
||||
let new_body = if end_of_stream {
|
||||
self.msg_window.finish(self.is_openai_stream.unwrap())
|
||||
} else {
|
||||
self.msg_window.pop(
|
||||
self.char_window_size * 2,
|
||||
self.byte_window_size * 2,
|
||||
self.is_openai_stream.unwrap(),
|
||||
)
|
||||
};
|
||||
self.replace_http_response_body(&new_body);
|
||||
DataAction::Continue
|
||||
}
|
||||
}
|
||||
@@ -586,7 +582,6 @@ impl HttpContextWrapper<AiDataMaskingConfig> for AiDataMasking {
|
||||
return DataAction::Continue;
|
||||
}
|
||||
let config = self.config.as_ref().unwrap();
|
||||
|
||||
let mut req_body = match String::from_utf8(req_body.clone()) {
|
||||
Ok(r) => r,
|
||||
Err(_) => return DataAction::Continue,
|
||||
|
||||
Reference in New Issue
Block a user