From f12048a2a406aa05915b23941171e6c7308d93ac Mon Sep 17 00:00:00 2001 From: bear Date: Mon, 13 Jun 2022 16:32:59 +0800 Subject: [PATCH] Enhance dispatch module (#121) --- modules/dispatch/src/lib.rs | 64 +++++++++++++++++--------- primitives/message-dispatch/src/lib.rs | 19 +++++++- 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/modules/dispatch/src/lib.rs b/modules/dispatch/src/lib.rs index 56b3c7a60..252894ff9 100644 --- a/modules/dispatch/src/lib.rs +++ b/modules/dispatch/src/lib.rs @@ -25,7 +25,9 @@ // Generated by `decl_event!` #![allow(clippy::unused_unit)] -use bp_message_dispatch::{CallOrigin, MessageDispatch, MessagePayload, SpecVersion}; +use bp_message_dispatch::{ + CallFilter, CallOrigin, IntoDispatchOrigin, MessageDispatch, MessagePayload, SpecVersion, +}; use bp_runtime::{ derive_account_id, messages::{DispatchFeePayment, MessageDispatchResult}, @@ -35,7 +37,7 @@ use codec::Encode; use frame_support::{ dispatch::Dispatchable, ensure, - traits::{Contains, Get}, + traits::Get, weights::{extract_actual_weight, GetDispatchInfo}, }; use frame_system::RawOrigin; @@ -81,7 +83,7 @@ pub mod pallet { /// /// The pallet will filter all incoming calls right before they're dispatched. If this /// filter rejects the call, special event (`Event::MessageCallRejected`) is emitted. - type CallFilter: Contains<>::Call>; + type CallFilter: CallFilter>::Call>; /// The type that is used to wrap the `Self::Call` when it is moved over bridge. /// /// The idea behind this is to avoid `Call` conversion/decoding until we'll be sure @@ -93,6 +95,12 @@ pub mod pallet { /// /// Used when deriving target chain AccountIds from source chain AccountIds. type AccountIdConverter: sp_runtime::traits::Convert; + /// The type is used to customize the dispatch call origin. + type IntoDispatchOrigin: IntoDispatchOrigin< + Self::AccountId, + >::Call, + Self::Origin, + >; } type BridgeMessageIdOf = >::BridgeMessageId; @@ -138,7 +146,9 @@ pub mod pallet { } } -impl, I: 'static> MessageDispatch for Pallet { +impl, I: 'static> + MessageDispatch>::Call> for Pallet +{ type Message = MessagePayload< T::SourceChainAccountId, T::TargetChainAccountPublic, @@ -150,7 +160,7 @@ impl, I: 'static> MessageDispatch message.weight } - fn dispatch Result<(), ()>>( + fn dispatch>::Call) -> Result<(), ()>>( source_chain: ChainId, target_chain: ChainId, id: T::BridgeMessageId, @@ -217,7 +227,7 @@ impl, I: 'static> MessageDispatch }; // prepare dispatch origin - let origin_account = match message.origin { + let origin_derived_account = match message.origin { CallOrigin::SourceRoot => { let hex_id = derive_account_id::(source_chain, SourceAccount::Root); @@ -260,8 +270,12 @@ impl, I: 'static> MessageDispatch }, }; + // generate dispatch origin from origin account + let dispatch_origin = + T::IntoDispatchOrigin::into_dispatch_origin(&origin_derived_account, &call); + // filter the call - if !T::CallFilter::contains(&call) { + if !T::CallFilter::contains(&dispatch_origin, &call) { log::trace!( target: "runtime::bridge-dispatch", "Message {:?}/{:?}: the call ({:?}) is rejected by filter", @@ -299,9 +313,7 @@ impl, I: 'static> MessageDispatch // pay dispatch fee right before dispatch let pay_dispatch_fee_at_target_chain = message.dispatch_fee_payment == DispatchFeePayment::AtTargetChain; - if pay_dispatch_fee_at_target_chain - && pay_dispatch_fee(&origin_account, message.weight).is_err() - { + if pay_dispatch_fee_at_target_chain && pay_dispatch_fee(&dispatch_origin, &call).is_err() { log::trace!( target: "runtime::bridge-dispatch", "Failed to pay dispatch fee for dispatching message {:?}/{:?} with weight {}", @@ -312,18 +324,15 @@ impl, I: 'static> MessageDispatch Self::deposit_event(Event::MessageDispatchPaymentFailed( source_chain, id, - origin_account, + origin_derived_account, message.weight, )); return dispatch_result; } dispatch_result.dispatch_fee_paid_during_dispatch = pay_dispatch_fee_at_target_chain; - // finally dispatch message - let origin = RawOrigin::Signed(origin_account).into(); - log::trace!(target: "runtime::bridge-dispatch", "Message being dispatched is: {:.4096?}", &call); - let result = call.dispatch(origin); + let result = call.dispatch(dispatch_origin); let actual_call_weight = extract_actual_weight(&result, &dispatch_info); dispatch_result.dispatch_result = result.is_ok(); dispatch_result.unspent_weight = message.weight.saturating_sub(actual_call_weight); @@ -529,6 +538,7 @@ mod tests { type CallFilter = TestCallFilter; type EncodedCall = EncodedCall; type Event = Event; + type IntoDispatchOrigin = TestIntoDispatchOrigin; type SourceChainAccountId = AccountId; type TargetChainAccountPublic = TestAccountPublic; type TargetChainSignature = TestSignature; @@ -545,12 +555,20 @@ mod tests { pub struct TestCallFilter; - impl Contains for TestCallFilter { - fn contains(call: &Call) -> bool { + impl CallFilter for TestCallFilter { + fn contains(_origin: &Origin, call: &Call) -> bool { !matches!(*call, Call::System(frame_system::Call::fill_block { .. })) } } + pub struct TestIntoDispatchOrigin; + + impl IntoDispatchOrigin for TestIntoDispatchOrigin { + fn into_dispatch_origin(id: &AccountId, _call: &Call) -> Origin { + frame_system::RawOrigin::Signed(*id).into() + } + } + const TEST_SPEC_VERSION: SpecVersion = 0; const TEST_WEIGHT: Weight = 1_000_000_000; @@ -563,8 +581,9 @@ mod tests { origin: CallOrigin, call: Call, ) -> as MessageDispatch< - AccountId, + ::Origin, ::BridgeMessageId, + ::Call, >>::Message { MessagePayload { spec_version: TEST_SPEC_VERSION, @@ -578,8 +597,9 @@ mod tests { fn prepare_root_message( call: Call, ) -> as MessageDispatch< - AccountId, + ::Origin, ::BridgeMessageId, + ::Call, >>::Message { prepare_message(CallOrigin::SourceRoot, call) } @@ -587,8 +607,9 @@ mod tests { fn prepare_target_message( call: Call, ) -> as MessageDispatch< - AccountId, + ::Origin, ::BridgeMessageId, + ::Call, >>::Message { let origin = CallOrigin::TargetAccount(1, TestAccountPublic(1), TestSignature(1)); prepare_message(origin, call) @@ -597,8 +618,9 @@ mod tests { fn prepare_source_message( call: Call, ) -> as MessageDispatch< - AccountId, + ::Origin, ::BridgeMessageId, + ::Call, >>::Message { let origin = CallOrigin::SourceAccount(1); prepare_message(origin, call) diff --git a/primitives/message-dispatch/src/lib.rs b/primitives/message-dispatch/src/lib.rs index 07e448ee7..2e69d7881 100644 --- a/primitives/message-dispatch/src/lib.rs +++ b/primitives/message-dispatch/src/lib.rs @@ -35,7 +35,7 @@ pub type Weight = u64; pub type SpecVersion = u32; /// A generic trait to dispatch arbitrary messages delivered over the bridge. -pub trait MessageDispatch { +pub trait MessageDispatch { /// A type of the message to be dispatched. type Message: codec::Decode; @@ -58,7 +58,7 @@ pub trait MessageDispatch { /// the whole message). /// /// Returns unspent dispatch weight. - fn dispatch Result<(), ()>>( + fn dispatch Result<(), ()>>( source_chain: ChainId, target_chain: ChainId, id: BridgeMessageId, @@ -140,3 +140,18 @@ impl Size self.call.len() as _ } } + +/// Customize the dispatch origin before call dispatch. +pub trait IntoDispatchOrigin { + /// Generate the dispatch origin for the given call. + /// + /// Normally, the dispatch origin is one kind of frame_system::RawOrigin, however, sometimes + /// it is useful for a dispatch call with a custom origin. + fn into_dispatch_origin(id: &AccountId, call: &Call) -> Origin; +} + +/// A generic trait to filter calls that are allowed to be dispatched. +pub trait CallFilter { + /// Filter the call, you might need origin to in the filter. return false, if not allowed. + fn contains(origin: &Origin, call: &Call) -> bool; +}