Skip to content

Commit

Permalink
Replace NodeState with MonadState
Browse files Browse the repository at this point in the history
Previously, we need a separate `HasNodeState` class and `NodeStateT` monad
transformer because there was already a `StateT` in our stack, i.e., the one
containing the DRG. As that is gone, we can switch back to a regular
`MonadState` + `StateT`.

The cost is an orphan instance: `MonadRandom (StateT s m)`
  • Loading branch information
mrBliss committed Feb 20, 2020
1 parent 99787fa commit 49a35f6
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import qualified Cardano.Chain.UTxO as Impl

import Ouroboros.Network.Block

import Ouroboros.Consensus.Block
import Ouroboros.Consensus.Config
import Ouroboros.Consensus.Ledger.Dual
import Ouroboros.Consensus.Ledger.Extended
Expand Down Expand Up @@ -208,7 +209,7 @@ bridgeTransactionIds = Spec.Test.transactionIds

forgeDualByronBlock
:: forall m.
( HasNodeState_ () m -- @()@ is the @NodeState@ of PBFT
( HasNodeState (BlockProtocol DualByronBlock) m
, MonadRandom m
, HasCallStack
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import Ouroboros.Consensus.Byron.Protocol

forgeByronBlock
:: forall m.
( HasNodeState_ () m -- @()@ is the @NodeState@ of PBFT
( HasNodeState (BlockProtocol ByronBlock) m
, MonadRandom m
, HasCallStack
)
Expand Down Expand Up @@ -129,7 +129,7 @@ initBlockPayloads = BlockPayloads

forgeRegularBlock
:: forall m.
( HasNodeState_ () m -- @()@ is the @NodeState@ of PBFT
( HasNodeState (BlockProtocol ByronBlock) m
, MonadRandom m
, HasCallStack
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import Codec.Serialise (Serialise (..))
import Control.Monad (unless)
import Control.Monad.Except (throwError)
import Control.Monad.Identity (runIdentity)
import Control.Monad.State (get, put)
import Crypto.Random (MonadRandom)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
Expand Down Expand Up @@ -125,7 +126,7 @@ forgePraosFields :: ( HasNodeState (Praos c) m
-> (PraosExtraFields c -> toSign)
-> m (PraosFields c toSign)
forgePraosFields PraosNodeConfig{..} PraosProof{..} mkToSign = do
keyKES <- unPraosNodeState <$> getNodeState
keyKES <- unPraosNodeState <$> get
let signedFields = PraosExtraFields {
praosCreator = praosLeader
, praosRho = praosProofRho
Expand All @@ -142,7 +143,7 @@ forgePraosFields PraosNodeConfig{..} PraosProof{..} mkToSign = do
-- TODO : We should not update the key on each signing, but X slots
-- (for configurable param X)
newKey <- fromMaybe (error "mkOutoborosPayload: updateKES failed") <$> updateKES () keyKES
putNodeState (PraosNodeState newKey)
put (PraosNodeState newKey)
return $ PraosFields {
praosSignature = signature
, praosExtraFields = signedFields
Expand Down
3 changes: 2 additions & 1 deletion ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module Ouroboros.Consensus.NodeKernel (
) where

import Control.Monad
import Control.Monad.State (runStateT)
import Data.Map.Strict (Map)
import Data.Maybe (isJust, isNothing)
import Data.Proxy
Expand Down Expand Up @@ -504,7 +505,7 @@ forkBlockProduction maxBlockSizeOverride IS{..} BlockProduction{..} =

(a, nodeState') <-
runMonadRandom runMonadRandomDict $
runNodeStateT n nodeState
runStateT n nodeState

atomically $ writeTVar varState nodeState'
return a
Expand Down
62 changes: 2 additions & 60 deletions ouroboros-consensus/src/Ouroboros/Consensus/Protocol/Abstract.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,13 @@ module Ouroboros.Consensus.Protocol.Abstract (
ConsensusProtocol(..)
, NodeConfig
, SecurityParam(..)
-- * State monad for Ouroboros state
, HasNodeState
, HasNodeState_(..)
, NodeStateT
, NodeStateT_ -- opaque
, nodeStateT
, runNodeStateT
, evalNodeStateT
, runNodeState
, evalNodeState
) where

import Codec.Serialise (Serialise)
import Control.Monad.Except
import Control.Monad.State
import Crypto.Random (MonadRandom (..))
import Data.Functor.Identity
import Data.Typeable (Typeable)
import Data.Word (Word64)
import GHC.Generics (Generic)
Expand All @@ -44,8 +34,6 @@ import Cardano.Prelude (NoUnexpectedThunks)
import Ouroboros.Network.Block (BlockNo, HeaderHash, Point,
SlotNo (..))

import Ouroboros.Consensus.Util.Random

-- | Static node configuration
--
-- Every method in the 'ConsensusProtocol' class takes the node configuration as
Expand Down Expand Up @@ -238,51 +226,5 @@ class ( Show (ChainState p)
newtype SecurityParam = SecurityParam { maxRollbacks :: Word64 }
deriving (Show, Eq, Generic, NoUnexpectedThunks)

{-------------------------------------------------------------------------------
State monad
-------------------------------------------------------------------------------}

type HasNodeState p = HasNodeState_ (NodeState p)

-- | State monad for the Ouroboros specific state
--
-- We introduce this so that we can have both MonadState and OuroborosState
-- in a monad stack.
class Monad m => HasNodeState_ s m | m -> s where
getNodeState :: m s
putNodeState :: s -> m ()

instance HasNodeState_ s m => HasNodeState_ s (ChaChaT m) where
getNodeState = lift $ getNodeState
putNodeState = lift . putNodeState

{-------------------------------------------------------------------------------
Monad transformer introducing 'HasNodeState_'
-------------------------------------------------------------------------------}

newtype NodeStateT_ s m a = NodeStateT { unNodeStateT :: StateT s m a }
deriving (Functor, Applicative, Monad, MonadTrans)

type NodeStateT p = NodeStateT_ (NodeState p)

nodeStateT :: (s -> m (a, s)) -> NodeStateT_ s m a
nodeStateT = NodeStateT . StateT

runNodeStateT :: NodeStateT_ s m a -> s -> m (a, s)
runNodeStateT = runStateT . unNodeStateT

evalNodeStateT :: Monad m => NodeStateT_ s m a -> s -> m a
evalNodeStateT act = fmap fst . runNodeStateT act

runNodeState :: (forall m. Monad m => NodeStateT_ s m a) -> s -> (a, s)
runNodeState act = runIdentity . runNodeStateT act

evalNodeState :: (forall m. Monad m => NodeStateT_ s m a) -> s -> a
evalNodeState act = fst . runNodeState act

instance Monad m => HasNodeState_ s (NodeStateT_ s m) where
getNodeState = NodeStateT $ get
putNodeState = NodeStateT . put

instance MonadRandom m => MonadRandom (NodeStateT_ s m) where
getRandomBytes = lift . getRandomBytes
-- | Short-hand
type HasNodeState p = MonadState (NodeState p)
5 changes: 5 additions & 0 deletions ouroboros-consensus/src/Ouroboros/Consensus/Util/Random.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wno-orphans #-}
module Ouroboros.Consensus.Util.Random (
-- * Producing values in MonadRandom
generateElement
Expand Down Expand Up @@ -84,3 +85,7 @@ newtype RunMonadRandom m = RunMonadRandom
-- | Use the 'MonadRandom' instance for 'IO'.
runMonadRandomIO :: RunMonadRandom IO
runMonadRandomIO = RunMonadRandom id

-- | Handy during block production
instance MonadRandom m => MonadRandom (StateT s m) where
getRandomBytes = lift . getRandomBytes
25 changes: 0 additions & 25 deletions ouroboros-consensus/src/Ouroboros/Consensus/Util/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,15 @@ module Ouroboros.Consensus.Util.STM (
, Sim(..)
, simId
, simStateT
, simOuroborosStateT
, simChaChaT
) where

import Control.Monad.State
import Data.Coerce
import Data.Void
import Data.Word (Word64)
import GHC.Generics (Generic)
import GHC.Stack

import Ouroboros.Consensus.Protocol.Abstract
import Ouroboros.Consensus.Util.IOLike
import Ouroboros.Consensus.Util.Random
import Ouroboros.Consensus.Util.ResourceRegistry

{-------------------------------------------------------------------------------
Expand Down Expand Up @@ -135,23 +130,3 @@ simStateT stVar (Sim k) = Sim $ \(StateT f) -> do
(a, st') <- k (f st)
writeTVar stVar st'
return a

simOuroborosStateT :: IOLike m
=> StrictTVar m s
-> Sim n m
-> Sim (NodeStateT_ s n) m
simOuroborosStateT stVar (Sim k) = Sim $ \n -> do
st <- readTVar stVar
(a, st') <- k (runNodeStateT n st)
writeTVar stVar st'
return a

simChaChaT :: (IOLike m, Coercible a ChaChaDRG)
=> StrictTVar m a
-> Sim n m
-> Sim (ChaChaT n) m
simChaChaT stVar (Sim k) = Sim $ \n -> do
st <- readTVar stVar
(a, st') <- k (runChaChaT n (coerce st))
writeTVar stVar (coerce st')
return a

0 comments on commit 49a35f6

Please sign in to comment.