[feat] Support redis call with wasm-rust (#1417)

This commit is contained in:
007gzs
2024-10-29 19:35:02 +08:00
committed by GitHub
parent 93c1e5c2bb
commit 2219a17898
18 changed files with 1698 additions and 154 deletions

View File

@@ -30,6 +30,7 @@ use serde::de::DeserializeOwned;
lazy_static! {
static ref LOG: Log = Log::new("plugin_wrapper".to_string());
}
thread_local! {
static HTTP_CALLBACK_DISPATCHER: HttpCallbackDispatcher = HttpCallbackDispatcher::new();
}
@@ -49,7 +50,9 @@ where
None => None,
}
}
fn rule_matcher(&self) -> &SharedRuleMatcher<PluginConfig>;
fn create_http_context_wrapper(
&self,
_context_id: u32,
@@ -63,20 +66,24 @@ pub type HttpCallbackFn = dyn FnOnce(u16, &MultiMap<String, String>, Option<Vec<
pub struct HttpCallbackDispatcher {
call_fns: RefCell<HashMap<u32, Box<HttpCallbackFn>>>,
}
impl Default for HttpCallbackDispatcher {
fn default() -> Self {
Self::new()
}
}
impl HttpCallbackDispatcher {
pub fn new() -> Self {
HttpCallbackDispatcher {
call_fns: RefCell::new(HashMap::new()),
}
}
pub fn set(&self, token_id: u32, arg: Box<HttpCallbackFn>) {
self.call_fns.borrow_mut().insert(token_id, arg);
}
pub fn pop(&self, token_id: u32) -> Option<Box<HttpCallbackFn>> {
self.call_fns.borrow_mut().remove(&token_id)
}
@@ -91,31 +98,39 @@ where
_self_weak: Weak<RefCell<Box<dyn HttpContextWrapper<PluginConfig>>>>,
) {
}
fn log(&self) -> &Log {
&LOG
}
fn on_config(&mut self, _config: Rc<PluginConfig>) {}
fn on_http_request_complete_headers(
&mut self,
_headers: &MultiMap<String, String>,
) -> HeaderAction {
HeaderAction::Continue
}
fn on_http_response_complete_headers(
&mut self,
_headers: &MultiMap<String, String>,
) -> HeaderAction {
HeaderAction::Continue
}
fn cache_request_body(&self) -> bool {
false
}
fn cache_response_body(&self) -> bool {
false
}
fn on_http_request_complete_body(&mut self, _req_body: &Bytes) -> DataAction {
DataAction::Continue
}
fn on_http_response_complete_body(&mut self, _res_body: &Bytes) -> DataAction {
DataAction::Continue
}
@@ -123,6 +138,7 @@ where
fn replace_http_request_body(&mut self, body: &[u8]) {
self.set_http_request_body(0, i32::MAX as usize, body)
}
fn replace_http_response_body(&mut self, body: &[u8]) {
self.set_http_response_body(0, i32::MAX as usize, body)
}
@@ -164,8 +180,8 @@ where
if let Ok(token_id) = ret {
HTTP_CALLBACK_DISPATCHER.with(|dispatcher| dispatcher.set(token_id, call_fn));
self.log().debug(
&format!(
self.log().debugf(
format_args!(
"http call start, id: {}, cluster: {}, method: {}, url: {}, body: {:?}, timeout: {:?}",
token_id, cluster.cluster_name(), method.as_str(), raw_url, body, timeout
)
@@ -173,7 +189,8 @@ where
}
ret
} else {
self.log().critical(&format!("invalid raw_url:{}", raw_url));
self.log()
.criticalf(format_args!("invalid raw_url:{}", raw_url));
Err(Status::ParseFailure)
}
}
@@ -182,14 +199,13 @@ where
downcast_rs::impl_downcast!(HttpContextWrapper<PluginConfig> where PluginConfig: Default + DeserializeOwned + Clone);
pub struct PluginHttpWrapper<PluginConfig> {
req_headers: MultiMap<String, String>,
res_headers: MultiMap<String, String>,
req_body_len: usize,
res_body_len: usize,
config: Option<Rc<PluginConfig>>,
rule_matcher: SharedRuleMatcher<PluginConfig>,
http_content: Rc<RefCell<Box<dyn HttpContextWrapper<PluginConfig>>>>,
}
impl<PluginConfig> PluginHttpWrapper<PluginConfig>
where
PluginConfig: Default + DeserializeOwned + Clone + 'static,
@@ -203,8 +219,6 @@ where
.borrow_mut()
.init_self_weak(Rc::downgrade(&rc_content));
PluginHttpWrapper {
req_headers: MultiMap::new(),
res_headers: MultiMap::new(),
req_body_len: 0,
res_body_len: 0,
config: None,
@@ -212,10 +226,12 @@ where
http_content: rc_content,
}
}
fn get_http_call_fn(&mut self, token_id: u32) -> Option<Box<HttpCallbackFn>> {
HTTP_CALLBACK_DISPATCHER.with(|dispatcher| dispatcher.pop(token_id))
}
}
impl<PluginConfig> Context for PluginHttpWrapper<PluginConfig>
where
PluginConfig: Default + DeserializeOwned + Clone + 'static,
@@ -240,24 +256,24 @@ where
status_code = code;
normal_response = true;
} else {
self.http_content
.borrow()
.log()
.error(&format!("failed to parse status: {}", header_value));
self.http_content.borrow().log().errorf(format_args!(
"failed to parse status: {}",
header_value
));
status_code = 500;
}
}
headers.insert(k, header_value);
}
Err(_) => {
self.http_content.borrow().log().warn(&format!(
self.http_content.borrow().log().warnf(format_args!(
"http call response header contains non-ASCII characters header: {}",
k
));
}
}
}
self.http_content.borrow().log().warn(&format!(
self.http_content.borrow().log().debugf(format_args!(
"http call end, id: {}, code: {}, normal: {}, body: {:?}", /* */
token_id, status_code, normal_response, body
));
@@ -277,21 +293,25 @@ where
.borrow_mut()
.on_grpc_call_response(token_id, status_code, response_size)
}
fn on_grpc_stream_initial_metadata(&mut self, token_id: u32, num_elements: u32) {
self.http_content
.borrow_mut()
.on_grpc_stream_initial_metadata(token_id, num_elements)
}
fn on_grpc_stream_message(&mut self, token_id: u32, message_size: usize) {
self.http_content
.borrow_mut()
.on_grpc_stream_message(token_id, message_size)
}
fn on_grpc_stream_trailing_metadata(&mut self, token_id: u32, num_elements: u32) {
self.http_content
.borrow_mut()
.on_grpc_stream_trailing_metadata(token_id, num_elements)
}
fn on_grpc_stream_close(&mut self, token_id: u32, status_code: u32) {
self.http_content
.borrow_mut()
@@ -302,6 +322,7 @@ where
self.http_content.borrow_mut().on_done()
}
}
impl<PluginConfig> HttpContext for PluginHttpWrapper<PluginConfig>
where
PluginConfig: Default + DeserializeOwned + Clone + 'static,
@@ -312,13 +333,15 @@ where
if self.config.is_none() {
return HeaderAction::Continue;
}
let mut req_headers = MultiMap::new();
for (k, v) in self.get_http_request_headers_bytes() {
match String::from_utf8(v) {
Ok(header_value) => {
self.req_headers.insert(k, header_value);
req_headers.insert(k, header_value);
}
Err(_) => {
self.http_content.borrow().log().warn(&format!(
self.http_content.borrow().log().warnf(format_args!(
"request http header contains non-ASCII characters header: {}",
k
));
@@ -338,7 +361,7 @@ where
}
self.http_content
.borrow_mut()
.on_http_request_complete_headers(&self.req_headers)
.on_http_request_complete_headers(&req_headers)
}
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction {
@@ -383,13 +406,15 @@ where
if self.config.is_none() {
return HeaderAction::Continue;
}
let mut res_headers = MultiMap::new();
for (k, v) in self.get_http_response_headers_bytes() {
match String::from_utf8(v) {
Ok(header_value) => {
self.res_headers.insert(k, header_value);
res_headers.insert(k, header_value);
}
Err(_) => {
self.http_content.borrow().log().warn(&format!(
self.http_content.borrow().log().warnf(format_args!(
"response http header contains non-ASCII characters header: {}",
k
));
@@ -406,7 +431,7 @@ where
}
self.http_content
.borrow_mut()
.on_http_response_complete_headers(&self.res_headers)
.on_http_response_complete_headers(&res_headers)
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction {