diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index f23ccb2..e9204d1 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -142,13 +142,9 @@ impl Context for Filter { match op_res { Ok(operation) => { - if GrpcService::process_grpc_response( - operation, - resp_size, - &mut self.response_headers_to_add, - ) - .is_ok() - { + if let Ok(result) = GrpcService::process_grpc_response(operation, resp_size) { + // add the response headers + self.response_headers_to_add.extend(result.response_headers); // call the next op match self.operation_dispatcher.borrow_mut().next() { Ok(some_op) => { diff --git a/src/service.rs b/src/service.rs index 617f4f7..bc2b117 100644 --- a/src/service.rs +++ b/src/service.rs @@ -54,8 +54,7 @@ impl GrpcService { pub fn process_grpc_response( operation: Rc, resp_size: usize, - response_headers_to_add: &mut Vec<(String, String)>, - ) -> Result<(), StatusCode> { + ) -> Result { let failure_mode = operation.get_failure_mode(); if let Some(res_body_bytes) = hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size).unwrap() @@ -63,11 +62,9 @@ impl GrpcService { match GrpcMessageResponse::new(operation.get_service_type(), &res_body_bytes) { Ok(res) => match operation.get_service_type() { ServiceType::Auth => AuthService::process_auth_grpc_response(res, failure_mode), - ServiceType::RateLimit => RateLimitService::process_ratelimit_grpc_response( - res, - failure_mode, - response_headers_to_add, - ), + ServiceType::RateLimit => { + RateLimitService::process_ratelimit_grpc_response(res, failure_mode) + } }, Err(e) => { warn!( @@ -95,6 +92,20 @@ impl GrpcService { } } +struct GrpcResult { + pub response_headers: Vec<(String, String)>, +} +impl GrpcResult { + pub fn default() -> Self { + Self { + response_headers: Vec::new(), + } + } + pub fn new(response_headers: Vec<(String, String)>) -> Self { + Self { response_headers } + } +} + pub type GrpcCallFn = fn( upstream_name: &str, service_name: &str, diff --git a/src/service/auth.rs b/src/service/auth.rs index a1e4e96..33af068 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -6,7 +6,7 @@ use crate::envoy::{ SocketAddress, StatusCode, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::GrpcService; +use crate::service::{GrpcResult, GrpcService}; use chrono::{DateTime, FixedOffset}; use log::{debug, warn}; use protobuf::well_known_types::Timestamp; @@ -125,7 +125,7 @@ impl AuthService { pub fn process_auth_grpc_response( auth_resp: GrpcMessageResponse, failure_mode: FailureMode, - ) -> Result<(), StatusCode> { + ) -> Result { if let GrpcMessageResponse::Auth(check_response) = auth_resp { // store dynamic metadata in filter state store_metadata(check_response.get_dynamic_metadata()); @@ -153,7 +153,7 @@ impl AuthService { ) .unwrap() }); - Ok(()) + Ok(GrpcResult::default()) } Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { debug!("process_auth_grpc_response: received DeniedHttpResponse"); diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 275817a..4d8f242 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -3,7 +3,7 @@ use crate::envoy::{ RateLimitDescriptor, RateLimitRequest, RateLimitResponse, RateLimitResponse_Code, StatusCode, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::GrpcService; +use crate::service::{GrpcResult, GrpcService}; use log::warn; use protobuf::{Message, RepeatedField}; use proxy_wasm::hostcalls; @@ -38,8 +38,7 @@ impl RateLimitService { pub fn process_ratelimit_grpc_response( rl_resp: GrpcMessageResponse, failure_mode: FailureMode, - response_headers_to_add: &mut Vec<(String, String)>, - ) -> Result<(), StatusCode> { + ) -> Result { match rl_resp { GrpcMessageResponse::RateLimit(RateLimitResponse { overall_code: RateLimitResponse_Code::UNKNOWN, @@ -66,11 +65,13 @@ impl RateLimitService { response_headers_to_add: additional_headers, .. }) => { - additional_headers.iter().for_each(|header| { - response_headers_to_add - .push((header.get_key().to_owned(), header.get_value().to_owned())) - }); - Ok(()) + let result = GrpcResult::new( + additional_headers + .iter() + .map(|header| (header.get_key().to_owned(), header.get_value().to_owned())) + .collect(), + ); + Ok(result) } _ => { warn!("not a valid GrpcMessageResponse::RateLimit(RateLimitResponse)!");