Files
higress/plugins/wasm-rust/extensions/ai-data-masking/src/msg_win_openai.rs

357 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2025 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use higress_wasm_rust::event_stream::EventStream;
use serde::Deserialize;
use serde_json::Value;
use crate::msg_window::MessageWindow;
use crate::number_merge::NumberMerge;
#[derive(PartialEq, Eq, Clone, Copy)]
enum MsgFlag {
None,
Content,
ReasoningContent,
}
impl Default for MsgFlag {
fn default() -> Self {
Self::None
}
}
#[derive(Deserialize)]
struct Delta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
}
#[derive(Deserialize)]
struct Choices {
#[serde(default)]
index: i64,
#[serde(default)]
delta: Option<Delta>,
#[serde(default)]
finish_reason: Option<String>,
}
impl Delta {
fn get_flag_msg(&self, default_flag: &MsgFlag) -> (MsgFlag, &[u8]) {
if let Some(msg) = &self.content {
if !msg.is_empty() {
return (MsgFlag::Content, msg.as_bytes());
}
}
if let Some(msg) = &self.reasoning_content {
if !msg.is_empty() {
return (MsgFlag::ReasoningContent, msg.as_bytes());
}
}
(*default_flag, &[])
}
}
const USAGE_PATH: &str = "usage";
const CHOICES_PATH: &str = "choices";
type MessageLine = Vec<(MsgFlag, Vec<u8>)>;
#[derive(Default)]
struct MessageWindowOpenAi {
message_window: MessageWindow,
ret_messages: MessageLine,
flag: MsgFlag,
last_value: Value,
finish_reason: Option<String>,
}
impl MessageWindowOpenAi {
fn update(
&mut self,
data: &[u8],
flag: MsgFlag,
value: &Value,
finish_reason: &Option<String>,
) {
self.last_value = value.clone();
if data.is_empty() {
return;
}
if self.flag == MsgFlag::None {
self.flag = flag;
}
if self.flag != flag {
let last_flag = core::mem::replace(&mut self.flag, flag);
let msg = self.message_window.finish();
self.ret_messages.push((last_flag, msg));
}
self.message_window.update(data);
if let Some(fr) = finish_reason {
self.finish_reason = Some(fr.clone());
}
}
fn gen_value(&self, flag: &MsgFlag, msg: &[u8], finish: bool) -> Value {
let mut ret = self.last_value.clone();
match flag {
MsgFlag::Content => {
ret["delta"]["content"] = Value::String(String::from_utf8_lossy(msg).to_string());
if let Some(m) = ret["delta"].as_object_mut() {
m.remove("reasoning_content");
}
}
MsgFlag::ReasoningContent => {
ret["delta"]["reasoning_content"] =
Value::String(String::from_utf8_lossy(msg).to_string());
ret["delta"]["content"] = Value::String(String::new());
}
_ => {}
}
if finish {
ret["finish_reason"] = self
.finish_reason
.as_ref()
.map_or(Value::Null, |v| Value::String(v.to_string()));
} else {
ret["finish_reason"] = Value::Null;
}
ret
}
fn messages_to_value(&mut self) -> Vec<Value> {
let mut ret = Vec::new();
for (flag, msg) in core::mem::take(&mut self.ret_messages) {
ret.push(self.gen_value(&flag, &msg, false));
}
ret
}
fn pop(&mut self, char_window_size: usize, byte_window_size: usize) -> Vec<Value> {
let mut ret = self.messages_to_value();
let msg = self.message_window.pop(char_window_size, byte_window_size);
if !msg.is_empty() {
ret.push(self.gen_value(&self.flag, &msg, false));
}
ret
}
fn finish(&mut self) -> Vec<Value> {
let mut ret = self.messages_to_value();
let msg = self.message_window.finish();
let flag = core::mem::replace(&mut self.flag, MsgFlag::None);
ret.push(self.gen_value(&flag, &msg, true));
ret
}
fn iter_mut(&mut self) -> impl Iterator<Item = &mut Vec<u8>> {
self.ret_messages
.iter_mut()
.map(|(_, msg)| msg)
.chain(self.message_window.iter_mut())
}
}
#[derive(Default)]
pub(crate) struct MsgWindow {
stream_parser: EventStream,
base_message_window: MessageWindow,
message_windows: HashMap<i64, MessageWindowOpenAi>,
last_value: Value,
usage: NumberMerge,
}
impl MsgWindow {
fn update_event(&mut self, event: Vec<u8>) -> Option<Vec<u8>> {
if event.is_empty() || !event.starts_with(b"data:") {
Some(event)
} else if let Ok(res) = serde_json::from_slice::<Value>(&event[b"data:".len()..]) {
self.last_value = res;
if let Some(r) = self.last_value.as_object() {
if let Some(v) = r.get(USAGE_PATH) {
self.usage.add(v);
}
if let Some(v) = r.get(CHOICES_PATH) {
if let Some(a) = v.as_array() {
for item in a {
if let Ok(c) = serde_json::from_value::<Choices>(item.clone()) {
if let Some(d) = &c.delta {
let mw = self.message_windows.entry(c.index).or_default();
let (flag, msg) = d.get_flag_msg(&mw.flag);
mw.update(msg, flag, item, &c.finish_reason);
}
}
}
}
}
}
None
} else if event.starts_with(b"data: [DONE]") {
None
} else {
Some(event)
}
}
fn push_base(&mut self, data: &[u8]) {
self.base_message_window.update(data);
}
pub(crate) fn push(&mut self, data: &[u8], is_openai: bool) {
if is_openai {
self.stream_parser.update(data.to_vec());
while let Some(event) = self.stream_parser.next() {
if let Some(msg) = self.update_event(event) {
self.push_base(&msg);
}
}
} else {
self.push_base(data);
}
}
pub(crate) fn pop(
&mut self,
char_window_size: usize,
byte_window_size: usize,
is_openai: bool,
) -> Vec<u8> {
if !is_openai {
return self
.base_message_window
.pop(char_window_size, byte_window_size);
}
let mut ret = Vec::new();
for mw in self.message_windows.values_mut() {
for value in mw.pop(char_window_size, byte_window_size) {
let usage = self.usage.finish();
let mut ret_value = self.last_value.clone();
ret_value[CHOICES_PATH] = Value::Array(vec![value]);
ret_value[USAGE_PATH] = usage;
ret.extend(format!("data: {}\n\n", ret_value).as_bytes())
}
}
ret
}
pub(crate) fn finish(&mut self, is_openai: bool) -> Vec<u8> {
if !is_openai {
return self.base_message_window.finish();
}
if let Some(event) = self.stream_parser.flush() {
self.update_event(event);
}
let mut ret = Vec::new();
for mw in &mut self.message_windows.values_mut() {
for value in mw.finish() {
let usage = self.usage.finish();
let mut ret_value = self.last_value.clone();
ret_value[CHOICES_PATH] = Value::Array(vec![value]);
ret_value[USAGE_PATH] = usage;
ret.extend(format!("data: {}\n\n", ret_value).as_bytes())
}
}
ret
}
pub(crate) fn messages_iter_mut(&mut self) -> impl Iterator<Item = &mut Vec<u8>> {
self.base_message_window.iter_mut().chain(
self.message_windows
.values_mut()
.flat_map(|mw| mw.iter_mut()),
)
}
}
#[cfg(test)]
mod tests {
use rust_embed::Embed;
use super::*;
#[derive(Embed)]
#[folder = "test/"]
struct Asset;
#[derive(Deserialize)]
struct Res {
choices: Vec<Choices>,
}
impl Res {
fn get_text(&self) -> (String, String) {
let mut content = String::new();
let mut reasoning_content = String::new();
for choice in self.choices.iter() {
if let Some(delta) = &choice.delta {
if let Some(c) = &delta.content {
content += c;
}
if let Some(rc) = &delta.reasoning_content {
reasoning_content += rc;
}
}
}
(content, reasoning_content)
}
}
#[test]
fn test_msg() {
let mut msg_win = MsgWindow::default();
let data = raw_message("raw_message.txt");
let mut buffer = Vec::new();
for line in data.split("\n") {
msg_win.push(line.as_bytes(), true);
msg_win.push(b"\n\n", true);
for message in msg_win.messages_iter_mut() {
if let Ok(mut msg) = String::from_utf8(message.clone()) {
msg = msg.replace("Higress", "***higress***");
message.clear();
message.extend_from_slice(msg.as_bytes());
}
}
buffer.extend(msg_win.pop(7, 7, true));
}
buffer.extend(msg_win.finish(true));
let mut message = String::new();
let mut reasoning_message = String::new();
for line in buffer.split(|&x| x == b'\n') {
if line.is_empty() {
continue;
}
assert!(line.starts_with(b"data:"));
if line.starts_with(b"data: [DONE]") {
continue;
}
let des = serde_json::from_slice::<Res>(&line[b"data:".len()..]);
assert!(des.is_ok());
let res = des.unwrap();
let (c, rc) = res.get_text();
message.push_str(&c);
reasoning_message.push_str(&rc);
}
let res = "***higress*** 是一个基于 Istio 的高性能服务网格数据平面项目,旨在提供高吞吐量、低延迟和可扩展的服务通信管理。它为企业级应用提供了丰富的流量治理功能,如负载均衡、熔断、限流等,并支持多协议代理(包括 HTTP/1.1, HTTP/2, gRPC。***higress*** 的设计目标是优化 Istio 在大规模集群中的性能表现,满足高并发场景下的需求。";
assert_eq!(message, res);
assert_eq!(reasoning_message, res);
}
fn raw_message(file_name: &str) -> String {
if let Some(file) = Asset::get(file_name) {
if let Ok(data) = std::str::from_utf8(file.data.as_ref()) {
return data.to_string();
}
}
String::new()
}
}