diff --git a/examples/block-download/src/main.rs b/examples/block-download/src/main.rs index 82000b7b..90c1bbbd 100644 --- a/examples/block-download/src/main.rs +++ b/examples/block-download/src/main.rs @@ -25,8 +25,8 @@ fn main() { let bearer = Bearer::connect_tcp("relays-new.cardano-mainnet.iohk.io:3001").unwrap(); let mut plexer = StdPlexer::new(bearer); - let mut channel0 = plexer.use_channel(0); - let mut channel3 = plexer.use_channel(3); + let mut channel0 = plexer.use_channel(0).into(); + let mut channel3 = plexer.use_channel(3).into(); plexer.muxer.spawn(); plexer.demuxer.spawn(); diff --git a/examples/n2c-miniprotocols/src/main.rs b/examples/n2c-miniprotocols/src/main.rs index b1b5b7f3..dc361628 100644 --- a/examples/n2c-miniprotocols/src/main.rs +++ b/examples/n2c-miniprotocols/src/main.rs @@ -43,12 +43,12 @@ impl chainsync::Observer for LoggingObserver { } } -fn do_handshake(mut channel: multiplexer::StdChannel) { +fn do_handshake(mut channel: multiplexer::StdChannelBuffer) { let versions = handshake::n2c::VersionTable::v1_and_above(MAINNET_MAGIC); let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap(); } -fn do_localstate_query(mut channel: multiplexer::StdChannel) { +fn do_localstate_query(mut channel: multiplexer::StdChannelBuffer) { let agent = run_agent( localstate::OneShotClient::::initial( None, @@ -60,7 +60,7 @@ fn do_localstate_query(mut channel: multiplexer::StdChannel) { log::info!("state query result: {:?}", agent); } -fn do_chainsync(mut channel: multiplexer::StdChannel) { +fn do_chainsync(mut channel: multiplexer::StdChannelBuffer) { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(), @@ -89,9 +89,9 @@ fn main() { // setup the multiplexer by specifying the bearer and the IDs of the // miniprotocols to use let mut plexer = multiplexer::StdPlexer::new(bearer); - let channel0 = plexer.use_channel(0); - let channel7 = plexer.use_channel(7); - let channel5 = plexer.use_channel(5); + let channel0 = plexer.use_channel(0).into(); + let channel7 = plexer.use_channel(7).into(); + let channel5 = plexer.use_channel(5).into(); plexer.muxer.spawn(); plexer.demuxer.spawn(); diff --git a/examples/n2n-miniprotocols/src/main.rs b/examples/n2n-miniprotocols/src/main.rs index 44191824..44a34d03 100644 --- a/examples/n2n-miniprotocols/src/main.rs +++ b/examples/n2n-miniprotocols/src/main.rs @@ -1,6 +1,6 @@ use pallas::network::{ miniprotocols::{blockfetch, chainsync, handshake, run_agent, Point, MAINNET_MAGIC}, - multiplexer::{bearers::Bearer, StdChannel, StdPlexer}, + multiplexer::{agents::ChannelBuffer, bearers::Bearer, StdChannel, StdPlexer}, }; #[derive(Debug)] @@ -50,12 +50,12 @@ impl chainsync::Observer for LoggingObserver { } } -fn do_handshake(mut channel: StdChannel) { +fn do_handshake(mut channel: ChannelBuffer) { let versions = handshake::n2n::VersionTable::v4_and_above(MAINNET_MAGIC); let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap(); } -fn do_blockfetch(mut channel: StdChannel) { +fn do_blockfetch(mut channel: ChannelBuffer) { let range = ( Point::Specific( 43847831, @@ -77,7 +77,7 @@ fn do_blockfetch(mut channel: StdChannel) { println!("{:?}", agent); } -fn do_chainsync(mut channel: StdChannel) { +fn do_chainsync(mut channel: ChannelBuffer) { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(), @@ -106,9 +106,9 @@ fn main() { // setup the multiplexer by specifying the bearer and the IDs of the // miniprotocols to use let mut plexer = StdPlexer::new(bearer); - let channel0 = plexer.use_channel(0); - let channel3 = plexer.use_channel(3); - let channel2 = plexer.use_channel(2); + let channel0 = plexer.use_channel(0).into(); + let channel3 = plexer.use_channel(3).into(); + let channel2 = plexer.use_channel(2).into(); plexer.muxer.spawn(); plexer.demuxer.spawn(); diff --git a/pallas-miniprotocols/src/machines.rs b/pallas-miniprotocols/src/machines.rs index 2a2aeab4..91d0cce7 100644 --- a/pallas-miniprotocols/src/machines.rs +++ b/pallas-miniprotocols/src/machines.rs @@ -44,22 +44,22 @@ pub trait Agent: Sized { fn apply_inbound(self, msg: Self::Message) -> Transition; } -pub struct Runner<'c, A, C> +pub struct Runner where A: Agent, C: Channel, { agent: Cell>, - buffer: ChannelBuffer<'c, C>, + buffer: ChannelBuffer, } -impl<'c, A, C> Runner<'c, A, C> +impl Runner where A: Agent, A::Message: Fragment + std::fmt::Debug, C: Channel, { - pub fn new(agent: A, channel: &'c mut C) -> Self { + pub fn new(agent: A, channel: C) -> Self { Self { agent: Cell::new(Some(agent)), buffer: ChannelBuffer::new(channel), @@ -119,18 +119,16 @@ where } } -pub fn run_agent(agent: A, channel: &mut C) -> Transition +pub fn run_agent(agent: A, buffer: &mut ChannelBuffer) -> Transition where A: Agent, A::Message: Fragment + std::fmt::Debug, C: Channel, { - let mut buffer = ChannelBuffer::new(channel); - let mut agent = agent.apply_start()?; while !agent.is_done() { - agent = run_agent_step(agent, &mut buffer)?; + agent = run_agent_step(agent, buffer)?; } Ok(agent) diff --git a/pallas-multiplexer/src/agents.rs b/pallas-multiplexer/src/agents.rs index e5f74420..a00590e9 100644 --- a/pallas-multiplexer/src/agents.rs +++ b/pallas-multiplexer/src/agents.rs @@ -46,13 +46,13 @@ where } /// A channel abstraction to hide the complexity of partial payloads -pub struct ChannelBuffer<'c, C: Channel> { - channel: &'c mut C, +pub struct ChannelBuffer { + channel: C, temp: Vec, } -impl<'c, C: Channel> ChannelBuffer<'c, C> { - pub fn new(channel: &'c mut C) -> Self { +impl ChannelBuffer { + pub fn new(channel: C) -> Self { Self { channel, temp: Vec::new(), @@ -105,4 +105,14 @@ impl<'c, C: Channel> ChannelBuffer<'c, C> { } } } + + pub fn unwrap(self) -> C { + self.channel + } +} + +impl From for ChannelBuffer { + fn from(channel: C) -> Self { + ChannelBuffer::new(channel) + } } diff --git a/pallas-multiplexer/src/std.rs b/pallas-multiplexer/src/std.rs index fb5a09fe..d60a0424 100644 --- a/pallas-multiplexer/src/std.rs +++ b/pallas-multiplexer/src/std.rs @@ -1,4 +1,7 @@ -use crate::{agents, demux, mux, Payload}; +use crate::{ + agents::{self, ChannelBuffer}, + demux, mux, Payload, +}; use std::{ sync::{ @@ -103,6 +106,8 @@ impl demux::Demuxer { pub type StdChannel = (Sender, Receiver); +pub type StdChannelBuffer = ChannelBuffer; + impl agents::Channel for StdChannel { fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> { match self.0.send(payload) { diff --git a/pallas-multiplexer/tests/integration.rs b/pallas-multiplexer/tests/integration.rs index c217ed43..ecbec2de 100644 --- a/pallas-multiplexer/tests/integration.rs +++ b/pallas-multiplexer/tests/integration.rs @@ -72,10 +72,10 @@ fn multiple_messages_in_same_payload() { minicbor::encode(in_part1, &mut input).unwrap(); minicbor::encode(in_part2, &mut input).unwrap(); - let mut channel = std::sync::mpsc::channel(); + let channel = std::sync::mpsc::channel(); channel.0.send(input).unwrap(); - let mut buf = ChannelBuffer::new(&mut channel); + let mut buf = ChannelBuffer::new(channel); let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap(); let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap(); @@ -90,14 +90,14 @@ fn fragmented_message_in_multiple_payloads() { let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8); minicbor::encode(msg, &mut input).unwrap(); - let mut channel = std::sync::mpsc::channel(); + let channel = std::sync::mpsc::channel(); while !input.is_empty() { let chunk = Vec::from(input.drain(0..2).as_slice()); channel.0.send(chunk).unwrap(); } - let mut buf = ChannelBuffer::new(&mut channel); + let mut buf = ChannelBuffer::new(channel); let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap();