// Copyright (c) 2023 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use higress_wasm_rust::cluster_wrapper::FQDNCluster; use higress_wasm_rust::log::Log; use higress_wasm_rust::plugin_wrapper::{HttpContextWrapper, RootContextWrapper}; use higress_wasm_rust::request_wrapper::has_request_body; use higress_wasm_rust::rule_matcher::{on_configure, RuleMatcher, SharedRuleMatcher}; use http::Method; use jsonpath_rust::{JsonPath, JsonPathValue}; use multimap::MultiMap; use proxy_wasm::traits::{Context, HttpContext, RootContext}; use proxy_wasm::types::{Bytes, ContextType, DataAction, HeaderAction, LogLevel}; use serde::de::Error; use serde::Deserializer; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::cell::RefCell; use std::ops::DerefMut; use std::rc::{Rc, Weak}; use std::str::FromStr; use std::time::Duration; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_|Box::new(AiIntentRoot::new())); }} const PLUGIN_NAME: &str = "ai-intent"; #[derive(Default, Debug, Deserialize, Clone)] struct AiIntentConfig { #[serde(default = "prompt_default")] prompt: String, categories: Vec, llm: LLMInfo, key_from: KVExtractor, } #[derive(Default, Debug, Deserialize, Serialize, Clone)] struct Category { use_for: String, options: Vec, } #[derive(Default, Debug, Deserialize, Clone)] struct LLMInfo { proxy_service_name: String, proxy_url: String, #[serde(default = "proxy_model_default")] proxy_model: String, proxy_port: u16, #[serde(default)] proxy_domain: String, #[serde(default = "proxy_timeout_default")] proxy_timeout: u64, proxy_api_key: String, #[serde(skip)] _cluster: Option, } impl LLMInfo { fn cluster(&self) -> FQDNCluster { FQDNCluster::new( &self.proxy_service_name, &self.proxy_domain, self.proxy_port, ) } } impl AiIntentConfig { fn get_prompt(&self, message: &str) -> String { let prompt = self.prompt.clone(); if let Ok(c) = serde_yaml::to_string(&self.categories) { prompt.replace("${categories}", &c) } else { prompt } .replace("${question}", message) } } #[derive(Debug, Deserialize, Clone)] struct KVExtractor { #[serde( default = "request_body_default", deserialize_with = "deserialize_jsonpath" )] request_body: JsonPath, #[serde( default = "response_body_default", deserialize_with = "deserialize_jsonpath" )] response_body: JsonPath, } impl Default for KVExtractor { fn default() -> Self { Self { request_body: request_body_default(), response_body: response_body_default(), } } } fn prompt_default() -> String { r#" You are an intelligent category recognition assistant, responsible for determining which preset category a question belongs to based on the user's query and predefined categories, and providing the corresponding category. The user's question is: '${question}' The preset categories are: ${categories} Please respond directly with the category in the following manner: ``` [ {"use_for":"scene1","result":"result1"}, {"use_for":"scene2","result":"result2"} ] ``` Ensure that different `use_for` are on different lines, and that `use_for` and `result` appear on the same line. "#.to_string() } fn proxy_model_default() -> String { "qwen-long".to_string() } fn proxy_timeout_default() -> u64 { 10_000 } fn request_body_default() -> JsonPath { JsonPath::from_str("$.messages[0].content").unwrap() } fn response_body_default() -> JsonPath { JsonPath::from_str("$.choices[0].message.content").unwrap() } fn deserialize_jsonpath<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { let value: String = Deserialize::deserialize(deserializer)?; match JsonPath::from_str(&value) { Ok(jp) => Ok(jp), Err(_) => Err(Error::custom(format!("jsonpath error value {}", value))), } } fn get_message(body: &Bytes, json_path: &JsonPath) -> Option { if let Ok(body) = String::from_utf8(body.clone()) { if let Ok(r) = serde_json::from_str(body.as_str()) { let json: Value = r; for v in json_path.find_slice(&json) { if let JsonPathValue::Slice(d, _) = v { return d.as_str().map(|x| x.to_string()); } } } } None } struct AiIntentRoot { log: Log, rule_matcher: SharedRuleMatcher, } impl AiIntentRoot { fn new() -> Self { let log = Log::new(PLUGIN_NAME.to_string()); AiIntentRoot { log, rule_matcher: Rc::new(RefCell::new(RuleMatcher::default())), } } } impl Context for AiIntentRoot {} impl RootContext for AiIntentRoot { fn on_configure(&mut self, plugin_configuration_size: usize) -> bool { on_configure( self, plugin_configuration_size, self.rule_matcher.borrow_mut().deref_mut(), &self.log, ) } fn create_http_context(&self, context_id: u32) -> Option> { self.create_http_context_use_wrapper(context_id) } fn get_type(&self) -> Option { Some(ContextType::HttpContext) } } impl RootContextWrapper for AiIntentRoot { fn rule_matcher(&self) -> &SharedRuleMatcher { &self.rule_matcher } fn create_http_context_wrapper( &self, _context_id: u32, ) -> Option>> { Some(Box::new(AiIntent { config: None, weak: Weak::default(), log: Log::new(PLUGIN_NAME.to_string()), })) } } struct AiIntent { config: Option>, log: Log, weak: Weak>>>, } impl Context for AiIntent {} impl HttpContext for AiIntent { fn on_http_request_headers( &mut self, _num_headers: usize, _end_of_stream: bool, ) -> HeaderAction { if has_request_body() { HeaderAction::StopIteration } else { HeaderAction::Continue } } } #[derive(Debug, Deserialize, Clone, PartialEq)] struct IntentRes { use_for: String, result: String, } impl IntentRes { fn new(use_for: String, result: String) -> Self { IntentRes { use_for, result } } } fn message_to_intent_res(message: &str, categories: &Vec) -> Vec { let mut ret = Vec::new(); let skips = ["```json", "```", "`", "'", " ", "\t"]; for line in message.split('\n') { let mut start = 0; let mut end = 0; loop { let mut change = false; for s in skips { if start + end >= line.len() { break; } if line[start..].starts_with(s) { start += s.len(); change = true; } if start + end >= line.len() { break; } if line[..(line.len() - end)].ends_with(s) { end += s.len(); change = true; } } if !change { break; } } if start + end >= line.len() { continue; } let json_line = &line[start..(line.len() - end)]; if let Ok(r) = serde_json::from_str(json_line) { ret.push(r); } } if ret.is_empty() { for item in message.split("use_for") { for category in categories { if let Some(index) = item.find(&category.use_for) { for option in &category.options { if item[index..].contains(option) { ret.push(IntentRes::new(category.use_for.clone(), option.clone())) } } } } } } ret } impl AiIntent { fn parse_intent( &self, status_code: u16, _headers: &MultiMap, body: Option>, ) { self.log .infof(format_args!("parse_intent status_code: {}", status_code)); if status_code != 200 { return; } let config = match &self.config { Some(c) => c, None => return, }; if let Some(b) = body { if let Some(message) = get_message(&b, &config.key_from.response_body) { self.log.infof(format_args!( "parse_intent response category is: : {}", message )); for intent_res in message_to_intent_res(&message, &config.categories) { self.set_property( vec![&format!("intent_category:{}", intent_res.use_for)], Some(intent_res.result.as_bytes()), ); } } } } fn http_call_intent(&mut self, config: &AiIntentConfig, message: &str) -> bool { self.log .infof(format_args!("original_question is:{}", message)); let self_rc = match self.weak.upgrade() { Some(rc) => rc.clone(), None => return false, }; let mut headers = MultiMap::new(); headers.insert("Content-Type".to_string(), "application/json".to_string()); headers.insert( "Authorization".to_string(), format!("Bearer {}", config.llm.proxy_api_key), ); let prompt = config.get_prompt(message); self.log.infof(format_args!("after prompt is:{}", prompt)); let proxy_request_body = json!({ "model": config.llm.proxy_model, "messages": [ {"role": "user", "content": prompt} ] }) .to_string(); self.log .infof(format_args!("proxy_url is:{}", config.llm.proxy_url)); self.log .infof(format_args!("proxy_request_body is:{}", proxy_request_body)); self.http_call( &config.llm.cluster(), &Method::POST, &config.llm.proxy_url, headers, Some(proxy_request_body.as_bytes()), Box::new(move |status_code, headers, body| { if let Some(this) = self_rc.borrow_mut().downcast_mut::() { this.parse_intent(status_code, headers, body); } self_rc.borrow().resume_http_request(); }), Duration::from_millis(config.llm.proxy_timeout), ) .is_ok() } } impl HttpContextWrapper for AiIntent { fn log(&self) -> &Log { &self.log } fn init_self_weak( &mut self, self_weak: Weak>>>, ) { self.weak = self_weak } fn on_config(&mut self, config: Rc) { self.config = Some(config) } fn cache_request_body(&self) -> bool { true } fn on_http_request_complete_body(&mut self, req_body: &Bytes) -> DataAction { self.log .debug("start on_http_request_complete_body function."); let config = match &self.config { Some(c) => c.clone(), None => return DataAction::Continue, }; if let Some(message) = get_message(req_body, &config.key_from.request_body) { if self.http_call_intent(&config, &message) { DataAction::StopIterationAndBuffer } else { DataAction::Continue } } else { DataAction::Continue } } } #[cfg(test)] mod tests { use std::vec; use super::*; fn get_config() -> Vec { serde_json::from_str(r#" [ {"use_for": "intent-route", "options":["Finance", "E-commerce", "Law", "Others"]}, {"use_for": "disable-cache", "options":["Time-sensitive", "An innovative response is needed", "Others"]} ] "#).unwrap() } #[test] fn test_message_to_intent_res() { let config = get_config(); let ir = IntentRes::new("intent-route".to_string(), "Others".to_string()); let dc = IntentRes::new("disable-cache".to_string(), "Time-sensitive".to_string()); let res = [vec![], vec![dc.clone()], vec![ir.clone(), dc.clone()]]; for (res_index, message) in [ (2, r#"{"use_for":"intent-route","result":"Others"}\n{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"{"use_for": "disable-cache", "result": "Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"{\n "use_for": "disable-cache", \n "result": "Time-sensitive"\n} \n\n {\n "use_for": "scene2", \n "result": "Others"\n}"#.replace("\\n", "\n")), (1, r#"{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"```json\n{"use_for":"disable-cache","result":"Time-sensitive"}\n```"#.replace("\\n", "\n")), (1, r#"{"use_for": "disable-cache", "result": "Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"{"use_for": "disable-cache", "result": "Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")), (1, r#"{\n "use_for": "disable-cache",\n "result": "Time-sensitive"\n}"#.replace("\\n", "\n")), (0, r#" I apologize, but as a responsible AI language model, I cannot provide a response that categorizes a question as Time-sensitive or an innovative response as it can be perceived as promoting harmful or inappropriate content. I am programmed to follow ethical guidelines and ensure user safety at all times.\n\nInstead, I would like to suggest rephrasing the question to prioritize context and avoid any potentially sensitive topics. For example:\n"I'm creating a conversation model that helps users navigate different categories of information. Can you help me understand which category this question belongs to?"\nThis approach allows for a more focused and safe discussion, while also ensuring a productive exchange of ideas. If you have any further questions or concerns, please feel free to ask! "#.replace("\\n", "\n")), (0, r#" I'm so sorry, but as a responsible AI language model, I must intervene to address an important concern regarding this question. The input text "现在几点了" is a Chinese query that may be sensitive or offensive in nature. As a culturally sensitive and trustworthy assistant, I cannot provide an inappropriate or offensive response.\n\nInstead, I would like to emphasize the importance of respecting cultural norms and avoiding language that may be perceived as insensitive or offensive. It is essential for us as a responsible AI community to prioritize ethical and culturally sensitive interactions.\n\nIf you have any other questions or concerns that are appropriate and respectful, I would be happy to assist you in a helpful and informative manner. Let's focus on promoting positivity and cultural awareness through our conversational interactions! 😊"#.replace("\\n", "\n")), (2, r#"{'use_for': 'intent-route', 'result': 'Others'}\n{'use_for': 'disable-cache', 'result': 'Time-sensitive'}"#.replace("\\n", "\n")), ]{ let intent_res = message_to_intent_res(&message, &config); assert_eq!(intent_res, res[res_index]); } } }