From 76c9bea1926f7b82a23d57e8dca8e660b6fad8f3 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 1 Oct 2024 15:47:03 +0200 Subject: [PATCH 1/2] [refactor] OperationDispatcher not using RefCell for storing operations * Operations store its status and result in RefCell for interior mut * OperationDispatcher keeps a Vec of Rc, then indexes cloning the Rc instead of cloning the entire object. Signed-off-by: dd di cesare --- src/operation_dispatcher.rs | 180 +++++++++++++++++------------------- 1 file changed, 84 insertions(+), 96 deletions(-) diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 792bbca8..65b9d064 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -5,13 +5,13 @@ use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcMessageBuildFn, GrpcSe use log::error; use proxy_wasm::hostcalls; use proxy_wasm::types::{Bytes, MapType, Status}; -use std::cell::{RefCell, RefMut}; +use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; use std::time::Duration; #[allow(dead_code)] -#[derive(PartialEq, Debug, Clone)] +#[derive(PartialEq, Debug, Clone, Copy)] pub(crate) enum State { Pending, Waiting, @@ -36,8 +36,8 @@ impl State { #[allow(dead_code)] #[derive(Clone)] pub(crate) struct Operation { - state: State, - result: Result, + state: RefCell, + result: RefCell>, extension: Rc, action: Action, service: Rc, @@ -50,8 +50,8 @@ pub(crate) struct Operation { impl Operation { pub fn new(extension: Rc, action: Action, service: Rc) -> Self { Self { - state: State::Pending, - result: Ok(0), // Heuristics: zero represents that it's not been triggered, following `hostcalls` example + state: RefCell::new(State::Pending), + result: RefCell::new(Ok(0)), // Heuristics: zero represents that it's not been triggered, following `hostcalls` example extension, action, service, @@ -61,37 +61,50 @@ impl Operation { } } - fn trigger(&mut self) -> Result { - match self.state { + fn trigger(&self) -> Result { + match self.get_state() { State::Pending => { if let Some(message) = (self.grpc_message_build_fn)(self.get_extension_type(), &self.action) { - self.result = + let res = self.service .send(self.get_map_values_bytes_fn, self.grpc_call_fn, message); - self.state.next(); - self.result + self.set_result(res); + self.next_state(); + res } else { //todo: we need to move to and start the next action - self.state.done(); - Ok(1234) + self.done(); + self.get_result() } } State::Waiting => { - self.state.next(); - self.result + self.next_state(); + self.get_result() } - State::Done => self.result, + State::Done => self.get_result(), } } - pub fn get_state(&self) -> &State { - &self.state + fn next_state(&self) { + self.state.borrow_mut().next() + } + + fn done(&self) { + self.state.borrow_mut().done() + } + + pub fn get_state(&self) -> State { + *self.state.borrow() } pub fn get_result(&self) -> Result { - self.result + *self.result.borrow() + } + + fn set_result(&self, result: Result) { + *self.result.borrow_mut() = result; } pub fn get_extension_type(&self) -> &ExtensionType { @@ -105,8 +118,8 @@ impl Operation { #[allow(dead_code)] pub struct OperationDispatcher { - operations: RefCell>, - waiting_operations: RefCell>, // TODO(didierofrivia): Maybe keep references or Rc + operations: Vec>, + waiting_operations: HashMap>, service_handlers: HashMap>, } @@ -114,60 +127,50 @@ pub struct OperationDispatcher { impl OperationDispatcher { pub fn default() -> Self { OperationDispatcher { - operations: RefCell::new(vec![]), - waiting_operations: RefCell::new(HashMap::default()), + operations: vec![], + waiting_operations: HashMap::default(), service_handlers: HashMap::default(), } } pub fn new(service_handlers: HashMap>) -> Self { Self { service_handlers, - operations: RefCell::new(vec![]), - waiting_operations: RefCell::new(HashMap::new()), + operations: vec![], + waiting_operations: HashMap::new(), } } - pub fn get_operation(&self, token_id: u32) -> Option { - self.waiting_operations.borrow_mut().get(&token_id).cloned() + pub fn get_operation(&self, token_id: u32) -> Option> { + self.waiting_operations.get(&token_id).cloned() } - pub fn build_operations(&self, rule: &Rule) { - let mut operations: Vec = vec![]; + pub fn build_operations(&mut self, rule: &Rule) { + let mut operations: Vec> = vec![]; for action in rule.actions.iter() { // TODO(didierofrivia): Error handling if let Some(service) = self.service_handlers.get(&action.extension) { - operations.push(Operation::new( + operations.push(Rc::new(Operation::new( service.get_extension(), action.clone(), Rc::clone(service), - )) + ))) } } self.push_operations(operations); } - pub fn push_operations(&self, operations: Vec) { - self.operations.borrow_mut().extend(operations); + pub fn push_operations(&mut self, operations: Vec>) { + self.operations.extend(operations); } pub fn get_current_operation_state(&self) -> Option { self.operations - .borrow() .first() - .map(|operation| operation.get_state().clone()) + .map(|operation| operation.get_state()) } - pub fn get_current_operation_result(&self) -> Result { - self.operations.borrow().first().unwrap().get_result() - } - - pub fn next(&self) -> Option { - let operations = self.operations.borrow_mut(); - self.step(operations) - } - - fn step(&self, mut operations: RefMut>) -> Option { - if let Some((i, operation)) = operations.iter_mut().enumerate().next() { + pub fn next(&mut self) -> Option> { + if let Some((i, operation)) = self.operations.iter_mut().enumerate().next() { match operation.get_state() { State::Pending => { match operation.trigger() { @@ -178,13 +181,11 @@ impl OperationDispatcher { } State::Waiting => { // We index only if it was just transitioned to Waiting after triggering - self.waiting_operations - .borrow_mut() - .insert(token_id, operation.clone()); + self.waiting_operations.insert(token_id, operation.clone()); // TODO(didierofrivia): Decide on indexing the failed operations. Some(operation.clone()) } - State::Done => self.step(operations), + State::Done => self.next(), } } Err(status) => { @@ -198,11 +199,11 @@ impl OperationDispatcher { Some(operation.clone()) } State::Done => { - if let Ok(token_id) = operation.result { - self.waiting_operations.borrow_mut().remove(&token_id); + if let Ok(token_id) = operation.get_result() { + self.waiting_operations.remove(&token_id); } // If result was Err, means the operation wasn't indexed - operations.remove(i); - self.step(operations) + self.operations.remove(i); + self.next() } } } else { @@ -286,10 +287,13 @@ mod tests { } } - fn build_operation(grpc_call_fn_stub: GrpcCallFn, extension_type: ExtensionType) -> Operation { - Operation { - state: State::Pending, - result: Ok(0), + fn build_operation( + grpc_call_fn_stub: GrpcCallFn, + extension_type: ExtensionType, + ) -> Rc { + Rc::new(Operation { + state: RefCell::from(State::Pending), + result: RefCell::new(Ok(0)), extension: Rc::new(Extension { extension_type, endpoint: "local".to_string(), @@ -304,14 +308,14 @@ mod tests { grpc_call_fn: grpc_call_fn_stub, get_map_values_bytes_fn: get_map_values_bytes_fn_stub, grpc_message_build_fn: grpc_message_build_fn_stub, - } + }) } #[test] fn operation_getters() { let operation = build_operation(default_grpc_call_fn_stub, ExtensionType::RateLimit); - assert_eq!(*operation.get_state(), State::Pending); + assert_eq!(operation.get_state(), State::Pending); assert_eq!(*operation.get_extension_type(), ExtensionType::RateLimit); assert_eq!(*operation.get_failure_mode(), FailureMode::Deny); assert_eq!(operation.get_result(), Ok(0)); @@ -319,34 +323,34 @@ mod tests { #[test] fn operation_transition() { - let mut operation = build_operation(default_grpc_call_fn_stub, ExtensionType::RateLimit); - assert_eq!(operation.result, Ok(0)); - assert_eq!(*operation.get_state(), State::Pending); + let operation = build_operation(default_grpc_call_fn_stub, ExtensionType::RateLimit); + assert_eq!(operation.get_result(), Ok(0)); + assert_eq!(operation.get_state(), State::Pending); let mut res = operation.trigger(); assert_eq!(res, Ok(200)); - assert_eq!(*operation.get_state(), State::Waiting); + assert_eq!(operation.get_state(), State::Waiting); res = operation.trigger(); assert_eq!(res, Ok(200)); - assert_eq!(operation.result, Ok(200)); - assert_eq!(*operation.get_state(), State::Done); + assert_eq!(operation.get_result(), Ok(200)); + assert_eq!(operation.get_state(), State::Done); } #[test] fn operation_dispatcher_push_actions() { - let operation_dispatcher = OperationDispatcher::default(); + let mut operation_dispatcher = OperationDispatcher::default(); - assert_eq!(operation_dispatcher.operations.borrow().len(), 0); + assert_eq!(operation_dispatcher.operations.len(), 0); operation_dispatcher.push_operations(vec![build_operation( default_grpc_call_fn_stub, ExtensionType::RateLimit, )]); - assert_eq!(operation_dispatcher.operations.borrow().len(), 1); + assert_eq!(operation_dispatcher.operations.len(), 1); } #[test] fn operation_dispatcher_get_current_action_state() { - let operation_dispatcher = OperationDispatcher::default(); + let mut operation_dispatcher = OperationDispatcher::default(); operation_dispatcher.push_operations(vec![build_operation( default_grpc_call_fn_stub, ExtensionType::RateLimit, @@ -359,7 +363,7 @@ mod tests { #[test] fn operation_dispatcher_next() { - let operation_dispatcher = OperationDispatcher::default(); + let mut operation_dispatcher = OperationDispatcher::default(); fn grpc_call_fn_stub_66( _upstream_name: &str, @@ -388,15 +392,11 @@ mod tests { build_operation(grpc_call_fn_stub_77, ExtensionType::Auth), ]); - assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(0)); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Pending) ); - assert_eq!( - operation_dispatcher.waiting_operations.borrow_mut().len(), - 0 - ); + assert_eq!(operation_dispatcher.waiting_operations.len(), 0); let mut op = operation_dispatcher.next(); assert_eq!(op.clone().unwrap().get_result(), Ok(66)); @@ -404,15 +404,12 @@ mod tests { *op.clone().unwrap().get_extension_type(), ExtensionType::RateLimit ); - assert_eq!(*op.unwrap().get_state(), State::Waiting); - assert_eq!( - operation_dispatcher.waiting_operations.borrow_mut().len(), - 1 - ); + assert_eq!(op.unwrap().get_state(), State::Waiting); + assert_eq!(operation_dispatcher.waiting_operations.len(), 1); op = operation_dispatcher.next(); assert_eq!(op.clone().unwrap().get_result(), Ok(66)); - assert_eq!(*op.unwrap().get_state(), State::Done); + assert_eq!(op.unwrap().get_state(), State::Done); op = operation_dispatcher.next(); assert_eq!(op.clone().unwrap().get_result(), Ok(77)); @@ -420,26 +417,17 @@ mod tests { *op.clone().unwrap().get_extension_type(), ExtensionType::Auth ); - assert_eq!(*op.unwrap().get_state(), State::Waiting); - assert_eq!( - operation_dispatcher.waiting_operations.borrow_mut().len(), - 1 - ); + assert_eq!(op.unwrap().get_state(), State::Waiting); + assert_eq!(operation_dispatcher.waiting_operations.len(), 1); op = operation_dispatcher.next(); assert_eq!(op.clone().unwrap().get_result(), Ok(77)); - assert_eq!(*op.unwrap().get_state(), State::Done); - assert_eq!( - operation_dispatcher.waiting_operations.borrow_mut().len(), - 1 - ); + assert_eq!(op.unwrap().get_state(), State::Done); + assert_eq!(operation_dispatcher.waiting_operations.len(), 1); op = operation_dispatcher.next(); assert!(op.is_none()); assert!(operation_dispatcher.get_current_operation_state().is_none()); - assert_eq!( - operation_dispatcher.waiting_operations.borrow_mut().len(), - 0 - ); + assert_eq!(operation_dispatcher.waiting_operations.len(), 0); } } From 249d3cd1c1055d47238d7475f128ada5d66073b9 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 1 Oct 2024 15:49:39 +0200 Subject: [PATCH 2/2] [refactor] OperationDispatcher within a RefCell for interior mut Signed-off-by: dd di cesare --- src/filter/http_context.rs | 21 +++++++++++++-------- src/filter/root_context.rs | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index e28e5f6d..cbc7eea9 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -7,13 +7,14 @@ use crate::service::grpc_message::GrpcMessageResponse; use log::{debug, warn}; use proxy_wasm::traits::{Context, HttpContext}; use proxy_wasm::types::Action; +use std::cell::RefCell; use std::rc::Rc; pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub operation_dispatcher: OperationDispatcher, + pub operation_dispatcher: RefCell, } impl Filter { @@ -32,13 +33,15 @@ impl Filter { fn process_policy(&self, policy: &Policy) -> Action { if let Some(rule) = policy.find_rule_that_applies() { - self.operation_dispatcher.build_operations(rule); + self.operation_dispatcher + .borrow_mut() + .build_operations(rule); } else { debug!("#{} process_policy: no rule applied", self.context_id); return Action::Continue; } - if let Some(operation) = self.operation_dispatcher.next() { + if let Some(operation) = self.operation_dispatcher.borrow_mut().next() { match operation.get_result() { Ok(call_id) => { debug!("#{} initiated gRPC call (id# {})", self.context_id, call_id); @@ -101,11 +104,11 @@ impl Filter { } _ => {} } - self.operation_dispatcher.next(); + self.operation_dispatcher.borrow_mut().next(); } fn process_auth_grpc_response( - &mut self, + &self, auth_resp: GrpcMessageResponse, failure_mode: &FailureMode, ) { @@ -156,7 +159,7 @@ impl Filter { } } } - self.operation_dispatcher.next(); + self.operation_dispatcher.borrow_mut().next(); } } @@ -203,7 +206,9 @@ impl Context for Filter { self.context_id ); - if let Some(operation) = self.operation_dispatcher.get_operation(token_id) { + let some_op = self.operation_dispatcher.borrow().get_operation(token_id); + + if let Some(operation) = some_op { let failure_mode = &operation.get_failure_mode(); let res_body_bytes = match self.get_grpc_call_response_body(0, resp_size) { Some(bytes) => bytes, @@ -229,7 +234,7 @@ impl Context for Filter { ExtensionType::RateLimit => self.process_ratelimit_grpc_response(res, failure_mode), } - if let Some(_op) = self.operation_dispatcher.next() { + if let Some(_op) = self.operation_dispatcher.borrow_mut().next() { } else { self.resume_http_request() } diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index 6dcd4bdc..5e5a8aa6 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -55,7 +55,7 @@ impl RootContext for FilterRoot { context_id, config: Rc::clone(&self.config), response_headers_to_add: Vec::default(), - operation_dispatcher: OperationDispatcher::new(service_handlers), + operation_dispatcher: OperationDispatcher::new(service_handlers).into(), })) }