Change http_content to Rc in HttpWrapper (#1391)

This commit is contained in:
007gzs
2024-10-21 09:44:01 +08:00
committed by GitHub
parent 32e5a59ae0
commit d96994767c
11 changed files with 446 additions and 304 deletions

View File

@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::{Rc, Weak};
use std::time::Duration;
use crate::cluster_wrapper::Cluster;
@@ -28,10 +30,13 @@ use serde::de::DeserializeOwned;
lazy_static! {
static ref LOG: Log = Log::new("plugin_wrapper".to_string());
}
thread_local! {
static HTTP_CALLBACK_DISPATCHER: HttpCallbackDispatcher = HttpCallbackDispatcher::new();
}
pub trait RootContextWrapper<PluginConfig, HttpCallArg: 'static = ()>: RootContext
pub trait RootContextWrapper<PluginConfig>: RootContext
where
PluginConfig: Default + DeserializeOwned + 'static + Clone,
PluginConfig: Default + DeserializeOwned + Clone + 'static,
{
// fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
fn create_http_context_use_wrapper(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
@@ -48,38 +53,48 @@ where
fn create_http_context_wrapper(
&self,
_context_id: u32,
) -> Option<Box<dyn HttpContextWrapper<PluginConfig, HttpCallArg>>> {
) -> Option<Box<dyn HttpContextWrapper<PluginConfig>>> {
None
}
}
pub type HttpCallbackFn<T> = dyn FnOnce(&mut T, u16, &MultiMap<String, String>, Option<Vec<u8>>);
pub struct HttpCallArgStorage<HttpCallArg> {
args: HashMap<u32, HttpCallArg>,
pub type HttpCallbackFn = dyn FnOnce(u16, &MultiMap<String, String>, Option<Vec<u8>>);
pub struct HttpCallbackDispatcher {
call_fns: RefCell<HashMap<u32, Box<HttpCallbackFn>>>,
}
impl<HttpCallArg> Default for HttpCallArgStorage<HttpCallArg> {
impl Default for HttpCallbackDispatcher {
fn default() -> Self {
Self::new()
}
}
impl<HttpCallArg> HttpCallArgStorage<HttpCallArg> {
impl HttpCallbackDispatcher {
pub fn new() -> Self {
HttpCallArgStorage {
args: HashMap::new(),
HttpCallbackDispatcher {
call_fns: RefCell::new(HashMap::new()),
}
}
pub fn set(&mut self, token_id: u32, arg: HttpCallArg) {
self.args.insert(token_id, arg);
pub fn set(&self, token_id: u32, arg: Box<HttpCallbackFn>) {
self.call_fns.borrow_mut().insert(token_id, arg);
}
pub fn pop(&mut self, token_id: u32) -> Option<HttpCallArg> {
self.args.remove(&token_id)
pub fn pop(&self, token_id: u32) -> Option<Box<HttpCallbackFn>> {
self.call_fns.borrow_mut().remove(&token_id)
}
}
pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
pub trait HttpContextWrapper<PluginConfig>: HttpContext
where
PluginConfig: Default + DeserializeOwned + Clone + 'static,
{
fn init_self_weak(
&mut self,
_self_weak: Weak<RefCell<Box<dyn HttpContextWrapper<PluginConfig>>>>,
) {
}
fn log(&self) -> &Log {
&LOG
}
fn on_config(&mut self, _config: &PluginConfig) {}
fn on_config(&mut self, _config: Rc<PluginConfig>) {}
fn on_http_request_complete_headers(
&mut self,
_headers: &MultiMap<String, String>,
@@ -105,16 +120,6 @@ pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
DataAction::Continue
}
#[allow(clippy::too_many_arguments)]
fn on_http_call_response_detail(
&mut self,
_token_id: u32,
_arg: HttpCallArg,
_status_code: u16,
_headers: &MultiMap<String, String>,
_body: Option<Vec<u8>>,
) {
}
fn replace_http_request_body(&mut self, body: &[u8]) {
self.set_http_request_body(0, i32::MAX as usize, body)
}
@@ -122,10 +127,6 @@ pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
self.set_http_response_body(0, i32::MAX as usize, body)
}
fn get_http_call_storage(&mut self) -> Option<&mut HttpCallArgStorage<HttpCallArg>> {
None
}
#[allow(clippy::too_many_arguments)]
fn http_call(
&mut self,
@@ -134,7 +135,7 @@ pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
raw_url: &str,
headers: MultiMap<String, String>,
body: Option<&[u8]>,
arg: HttpCallArg,
call_fn: Box<HttpCallbackFn>,
timeout: Duration,
) -> Result<u32, Status> {
if let Ok(uri) = raw_url.parse::<Uri>() {
@@ -162,17 +163,13 @@ pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
);
if let Ok(token_id) = ret {
if let Some(storage) = self.get_http_call_storage() {
storage.set(token_id, arg);
self.log().debug(
&format!(
"http call start, id: {}, cluster: {}, method: {}, url: {}, body: {:?}, timeout: {:?}",
token_id, cluster.cluster_name(), method.as_str(), raw_url, body, timeout
)
);
} else {
return Err(Status::InternalFailure);
}
HTTP_CALLBACK_DISPATCHER.with(|dispatcher| dispatcher.set(token_id, call_fn));
self.log().debug(
&format!(
"http call start, id: {}, cluster: {}, method: {}, url: {}, body: {:?}, timeout: {:?}",
token_id, cluster.cluster_name(), method.as_str(), raw_url, body, timeout
)
);
}
ret
} else {
@@ -181,20 +178,30 @@ pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
}
}
}
pub struct PluginHttpWrapper<PluginConfig, HttpCallArg = ()> {
downcast_rs::impl_downcast!(HttpContextWrapper<PluginConfig> where PluginConfig: Default + DeserializeOwned + Clone);
pub struct PluginHttpWrapper<PluginConfig> {
req_headers: MultiMap<String, String>,
res_headers: MultiMap<String, String>,
req_body_len: usize,
res_body_len: usize,
config: Option<PluginConfig>,
config: Option<Rc<PluginConfig>>,
rule_matcher: SharedRuleMatcher<PluginConfig>,
http_content: Box<dyn HttpContextWrapper<PluginConfig, HttpCallArg>>,
http_content: Rc<RefCell<Box<dyn HttpContextWrapper<PluginConfig>>>>,
}
impl<PluginConfig, HttpCallArg> PluginHttpWrapper<PluginConfig, HttpCallArg> {
impl<PluginConfig> PluginHttpWrapper<PluginConfig>
where
PluginConfig: Default + DeserializeOwned + Clone + 'static,
{
pub fn new(
rule_matcher: &SharedRuleMatcher<PluginConfig>,
http_content: Box<dyn HttpContextWrapper<PluginConfig, HttpCallArg>>,
http_content: Box<dyn HttpContextWrapper<PluginConfig>>,
) -> Self {
let rc_content = Rc::new(RefCell::new(http_content));
rc_content
.borrow_mut()
.init_self_weak(Rc::downgrade(&rc_content));
PluginHttpWrapper {
req_headers: MultiMap::new(),
res_headers: MultiMap::new(),
@@ -202,18 +209,17 @@ impl<PluginConfig, HttpCallArg> PluginHttpWrapper<PluginConfig, HttpCallArg> {
res_body_len: 0,
config: None,
rule_matcher: rule_matcher.clone(),
http_content,
http_content: rc_content,
}
}
fn get_http_call_arg(&mut self, token_id: u32) -> Option<HttpCallArg> {
if let Some(storage) = self.http_content.get_http_call_storage() {
storage.pop(token_id)
} else {
None
}
fn get_http_call_fn(&mut self, token_id: u32) -> Option<Box<HttpCallbackFn>> {
HTTP_CALLBACK_DISPATCHER.with(|dispatcher| dispatcher.pop(token_id))
}
}
impl<PluginConfig, HttpCallArg> Context for PluginHttpWrapper<PluginConfig, HttpCallArg> {
impl<PluginConfig> Context for PluginHttpWrapper<PluginConfig>
where
PluginConfig: Default + DeserializeOwned + Clone + 'static,
{
fn on_http_call_response(
&mut self,
token_id: u32,
@@ -221,7 +227,7 @@ impl<PluginConfig, HttpCallArg> Context for PluginHttpWrapper<PluginConfig, Http
body_size: usize,
num_trailers: usize,
) {
if let Some(arg) = self.get_http_call_arg(token_id) {
if let Some(call_fn) = self.get_http_call_fn(token_id) {
let body = self.get_http_call_response_body(0, body_size);
let mut headers = MultiMap::new();
let mut status_code = 502;
@@ -235,6 +241,7 @@ impl<PluginConfig, HttpCallArg> Context for PluginHttpWrapper<PluginConfig, Http
normal_response = true;
} else {
self.http_content
.borrow()
.log()
.error(&format!("failed to parse status: {}", header_value));
status_code = 500;
@@ -243,58 +250,61 @@ impl<PluginConfig, HttpCallArg> Context for PluginHttpWrapper<PluginConfig, Http
headers.insert(k, header_value);
}
Err(_) => {
self.http_content.log().warn(&format!(
self.http_content.borrow().log().warn(&format!(
"http call response header contains non-ASCII characters header: {}",
k
));
}
}
}
self.http_content.log().warn(&format!(
"http call end, id: {}, code: {}, normal: {}, body: {:?}",
self.http_content.borrow().log().warn(&format!(
"http call end, id: {}, code: {}, normal: {}, body: {:?}", /* */
token_id, status_code, normal_response, body
));
self.http_content.on_http_call_response_detail(
token_id,
arg,
status_code,
&headers,
body,
)
call_fn(status_code, &headers, body)
} else {
self.http_content
.on_http_call_response(token_id, num_headers, body_size, num_trailers)
self.http_content.borrow_mut().on_http_call_response(
token_id,
num_headers,
body_size,
num_trailers,
)
}
}
fn on_grpc_call_response(&mut self, token_id: u32, status_code: u32, response_size: usize) {
self.http_content
.borrow_mut()
.on_grpc_call_response(token_id, status_code, response_size)
}
fn on_grpc_stream_initial_metadata(&mut self, token_id: u32, num_elements: u32) {
self.http_content
.borrow_mut()
.on_grpc_stream_initial_metadata(token_id, num_elements)
}
fn on_grpc_stream_message(&mut self, token_id: u32, message_size: usize) {
self.http_content
.borrow_mut()
.on_grpc_stream_message(token_id, message_size)
}
fn on_grpc_stream_trailing_metadata(&mut self, token_id: u32, num_elements: u32) {
self.http_content
.borrow_mut()
.on_grpc_stream_trailing_metadata(token_id, num_elements)
}
fn on_grpc_stream_close(&mut self, token_id: u32, status_code: u32) {
self.http_content
.borrow_mut()
.on_grpc_stream_close(token_id, status_code)
}
fn on_done(&mut self) -> bool {
self.http_content.on_done()
self.http_content.borrow_mut().on_done()
}
}
impl<PluginConfig, HttpCallArg> HttpContext for PluginHttpWrapper<PluginConfig, HttpCallArg>
impl<PluginConfig> HttpContext for PluginHttpWrapper<PluginConfig>
where
PluginConfig: Default + DeserializeOwned + Clone,
PluginConfig: Default + DeserializeOwned + Clone + 'static,
{
fn on_http_request_headers(&mut self, num_headers: usize, end_of_stream: bool) -> HeaderAction {
let binding = self.rule_matcher.borrow();
@@ -306,7 +316,7 @@ where
self.req_headers.insert(k, header_value);
}
Err(_) => {
self.http_content.log().warn(&format!(
self.http_content.borrow().log().warn(&format!(
"request http header contains non-ASCII characters header: {}",
k
));
@@ -315,22 +325,25 @@ where
}
if let Some(config) = &self.config {
self.http_content.on_config(config);
self.http_content.borrow_mut().on_config(config.clone());
}
let ret = self
.http_content
.borrow_mut()
.on_http_request_headers(num_headers, end_of_stream);
if ret != HeaderAction::Continue {
return ret;
}
self.http_content
.borrow_mut()
.on_http_request_complete_headers(&self.req_headers)
}
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction {
if !self.http_content.cache_request_body() {
if !self.http_content.borrow().cache_request_body() {
return self
.http_content
.borrow_mut()
.on_http_request_body(body_size, end_of_stream);
}
self.req_body_len += body_size;
@@ -343,11 +356,15 @@ where
req_body = body;
}
}
self.http_content.on_http_request_complete_body(&req_body)
self.http_content
.borrow_mut()
.on_http_request_complete_body(&req_body)
}
fn on_http_request_trailers(&mut self, num_trailers: usize) -> Action {
self.http_content.on_http_request_trailers(num_trailers)
self.http_content
.borrow_mut()
.on_http_request_trailers(num_trailers)
}
fn on_http_response_headers(
@@ -361,7 +378,7 @@ where
self.res_headers.insert(k, header_value);
}
Err(_) => {
self.http_content.log().warn(&format!(
self.http_content.borrow().log().warn(&format!(
"response http header contains non-ASCII characters header: {}",
k
));
@@ -371,18 +388,21 @@ where
let ret = self
.http_content
.borrow_mut()
.on_http_response_headers(num_headers, end_of_stream);
if ret != HeaderAction::Continue {
return ret;
}
self.http_content
.borrow_mut()
.on_http_response_complete_headers(&self.res_headers)
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction {
if !self.http_content.cache_response_body() {
if !self.http_content.borrow().cache_response_body() {
return self
.http_content
.borrow_mut()
.on_http_response_body(body_size, end_of_stream);
}
self.res_body_len += body_size;
@@ -397,14 +417,18 @@ where
res_body = body;
}
}
self.http_content.on_http_response_complete_body(&res_body)
self.http_content
.borrow_mut()
.on_http_response_complete_body(&res_body)
}
fn on_http_response_trailers(&mut self, num_trailers: usize) -> Action {
self.http_content.on_http_response_trailers(num_trailers)
self.http_content
.borrow_mut()
.on_http_response_trailers(num_trailers)
}
fn on_log(&mut self) {
self.http_content.on_log()
self.http_content.borrow_mut().on_log()
}
}