Rust wrappers (#1367)

This commit is contained in:
007gzs
2024-10-09 17:58:43 +08:00
committed by GitHub
parent 93317adbc7
commit e126f3a888
9 changed files with 1033 additions and 32 deletions

View File

@@ -12,15 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::time::Duration;
use crate::cluster_wrapper::Cluster;
use crate::log::Log;
use crate::rule_matcher::SharedRuleMatcher;
use http::{method::Method, Uri};
use lazy_static::lazy_static;
use multimap::MultiMap;
use proxy_wasm::hostcalls::log;
use proxy_wasm::traits::{Context, HttpContext, RootContext};
use proxy_wasm::types::LogLevel;
use proxy_wasm::types::{Action, Bytes, DataAction, HeaderAction};
use proxy_wasm::types::{Action, Bytes, DataAction, HeaderAction, Status};
use serde::de::DeserializeOwned;
pub trait RootContextWrapper<PluginConfig>: RootContext
lazy_static! {
static ref LOG: Log = Log::new("plugin_wrapper".to_string());
}
pub trait RootContextWrapper<PluginConfig, HttpCallArg: 'static = ()>: RootContext
where
PluginConfig: Default + DeserializeOwned + 'static + Clone,
{
@@ -39,11 +48,37 @@ where
fn create_http_context_wrapper(
&self,
_context_id: u32,
) -> Option<Box<dyn HttpContextWrapper<PluginConfig>>> {
) -> Option<Box<dyn HttpContextWrapper<PluginConfig, HttpCallArg>>> {
None
}
}
pub trait HttpContextWrapper<PluginConfig>: HttpContext {
pub type HttpCallbackFn<T> = dyn FnOnce(&mut T, u16, &MultiMap<String, String>, Option<Vec<u8>>);
pub struct HttpCallArgStorage<HttpCallArg> {
args: HashMap<u32, HttpCallArg>,
}
impl<HttpCallArg> Default for HttpCallArgStorage<HttpCallArg> {
fn default() -> Self {
Self::new()
}
}
impl<HttpCallArg> HttpCallArgStorage<HttpCallArg> {
pub fn new() -> Self {
HttpCallArgStorage {
args: HashMap::new(),
}
}
pub fn set(&mut self, token_id: u32, arg: HttpCallArg) {
self.args.insert(token_id, arg);
}
pub fn pop(&mut self, token_id: u32) -> Option<HttpCallArg> {
self.args.remove(&token_id)
}
}
pub trait HttpContextWrapper<PluginConfig, HttpCallArg = ()>: HttpContext {
fn log(&self) -> &Log {
&LOG
}
fn on_config(&mut self, _config: &PluginConfig) {}
fn on_http_request_complete_headers(
&mut self,
@@ -69,26 +104,96 @@ pub trait HttpContextWrapper<PluginConfig>: HttpContext {
fn on_http_response_complete_body(&mut self, _res_body: &Bytes) -> DataAction {
DataAction::Continue
}
#[allow(clippy::too_many_arguments)]
fn on_http_call_response_detail(
&mut self,
_token_id: u32,
_arg: HttpCallArg,
_status_code: u16,
_headers: &MultiMap<String, String>,
_body: Option<Vec<u8>>,
) {
}
fn replace_http_request_body(&mut self, body: &[u8]) {
self.set_http_request_body(0, i32::MAX as usize, body)
}
fn replace_http_response_body(&mut self, body: &[u8]) {
self.set_http_response_body(0, i32::MAX as usize, body)
}
fn get_http_call_storage(&mut self) -> Option<&mut HttpCallArgStorage<HttpCallArg>> {
None
}
#[allow(clippy::too_many_arguments)]
fn http_call(
&mut self,
cluster: &dyn Cluster,
method: &Method,
raw_url: &str,
headers: MultiMap<String, String>,
body: Option<&[u8]>,
arg: HttpCallArg,
timeout: Duration,
) -> Result<u32, Status> {
if let Ok(uri) = raw_url.parse::<Uri>() {
let mut authority = cluster.host_name();
if let Some(host) = uri.host() {
authority = host.to_string();
}
let mut path = uri.path().to_string();
if let Some(query) = uri.query() {
path = format!("{}?{}", path, query);
}
let mut headers_vec = Vec::new();
for (k, v) in headers.iter() {
headers_vec.push((k.as_str(), v.as_str()));
}
headers_vec.push((":method", method.as_str()));
headers_vec.push((":path", &path));
headers_vec.push((":authority", &authority));
let ret = self.dispatch_http_call(
&cluster.cluster_name(),
headers_vec,
body,
Vec::new(),
timeout,
);
if let Ok(token_id) = ret {
if let Some(storage) = self.get_http_call_storage() {
storage.set(token_id, arg);
self.log().debug(
&format!(
"http call start, id: {}, cluster: {}, method: {}, url: {}, body: {:?}, timeout: {:?}",
token_id, cluster.cluster_name(), method.as_str(), raw_url, body, timeout
)
);
} else {
return Err(Status::InternalFailure);
}
}
ret
} else {
self.log().critical(&format!("invalid raw_url:{}", raw_url));
Err(Status::ParseFailure)
}
}
}
pub struct PluginHttpWrapper<PluginConfig> {
pub struct PluginHttpWrapper<PluginConfig, HttpCallArg = ()> {
req_headers: MultiMap<String, String>,
res_headers: MultiMap<String, String>,
req_body_len: usize,
res_body_len: usize,
config: Option<PluginConfig>,
rule_matcher: SharedRuleMatcher<PluginConfig>,
http_content: Box<dyn HttpContextWrapper<PluginConfig>>,
http_content: Box<dyn HttpContextWrapper<PluginConfig, HttpCallArg>>,
}
impl<PluginConfig> PluginHttpWrapper<PluginConfig> {
impl<PluginConfig, HttpCallArg> PluginHttpWrapper<PluginConfig, HttpCallArg> {
pub fn new(
rule_matcher: &SharedRuleMatcher<PluginConfig>,
http_content: Box<dyn HttpContextWrapper<PluginConfig>>,
http_content: Box<dyn HttpContextWrapper<PluginConfig, HttpCallArg>>,
) -> Self {
PluginHttpWrapper {
req_headers: MultiMap::new(),
@@ -100,8 +205,15 @@ impl<PluginConfig> PluginHttpWrapper<PluginConfig> {
http_content,
}
}
fn get_http_call_arg(&mut self, token_id: u32) -> Option<HttpCallArg> {
if let Some(storage) = self.http_content.get_http_call_storage() {
storage.pop(token_id)
} else {
None
}
}
}
impl<PluginConfig> Context for PluginHttpWrapper<PluginConfig> {
impl<PluginConfig, HttpCallArg> Context for PluginHttpWrapper<PluginConfig, HttpCallArg> {
fn on_http_call_response(
&mut self,
token_id: u32,
@@ -109,8 +221,50 @@ impl<PluginConfig> Context for PluginHttpWrapper<PluginConfig> {
body_size: usize,
num_trailers: usize,
) {
self.http_content
.on_http_call_response(token_id, num_headers, body_size, num_trailers)
if let Some(arg) = self.get_http_call_arg(token_id) {
let body = self.get_http_call_response_body(0, body_size);
let mut headers = MultiMap::new();
let mut status_code = 502;
let mut normal_response = false;
for (k, v) in self.get_http_call_response_headers_bytes() {
match String::from_utf8(v) {
Ok(header_value) => {
if k == ":status" {
if let Ok(code) = header_value.parse::<u16>() {
status_code = code;
normal_response = true;
} else {
self.http_content
.log()
.error(&format!("failed to parse status: {}", header_value));
status_code = 500;
}
}
headers.insert(k, header_value);
}
Err(_) => {
self.http_content.log().warn(&format!(
"http call response header contains non-ASCII characters header: {}",
k
));
}
}
}
self.http_content.log().warn(&format!(
"http call end, id: {}, code: {}, normal: {}, body: {:?}",
token_id, status_code, normal_response, body
));
self.http_content.on_http_call_response_detail(
token_id,
arg,
status_code,
&headers,
body,
)
} else {
self.http_content
.on_http_call_response(token_id, num_headers, body_size, num_trailers)
}
}
fn on_grpc_call_response(&mut self, token_id: u32, status_code: u32, response_size: usize) {
@@ -138,7 +292,7 @@ impl<PluginConfig> Context for PluginHttpWrapper<PluginConfig> {
self.http_content.on_done()
}
}
impl<PluginConfig> HttpContext for PluginHttpWrapper<PluginConfig>
impl<PluginConfig, HttpCallArg> HttpContext for PluginHttpWrapper<PluginConfig, HttpCallArg>
where
PluginConfig: Default + DeserializeOwned + Clone,
{
@@ -152,15 +306,10 @@ where
self.req_headers.insert(k, header_value);
}
Err(_) => {
log(
LogLevel::Warn,
format!(
"request http header contains non-ASCII characters header: {}",
k
)
.as_str(),
)
.unwrap();
self.http_content.log().warn(&format!(
"request http header contains non-ASCII characters header: {}",
k
));
}
}
}
@@ -212,15 +361,10 @@ where
self.res_headers.insert(k, header_value);
}
Err(_) => {
log(
LogLevel::Warn,
format!(
"response http header contains non-ASCII characters header: {}",
k
)
.as_str(),
)
.unwrap();
self.http_content.log().warn(&format!(
"response http header contains non-ASCII characters header: {}",
k
));
}
}
}