Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(multiplexer): Use buffers that own the inner channel #113

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/block-download/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ fn main() {
let bearer = Bearer::connect_tcp("relays-new.cardano-mainnet.iohk.io:3001").unwrap();

let mut plexer = StdPlexer::new(bearer);
let mut channel0 = plexer.use_channel(0);
let mut channel3 = plexer.use_channel(3);
let mut channel0 = plexer.use_channel(0).into();
let mut channel3 = plexer.use_channel(3).into();

plexer.muxer.spawn();
plexer.demuxer.spawn();
Expand Down
12 changes: 6 additions & 6 deletions examples/n2c-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ impl chainsync::Observer<chainsync::HeaderContent> for LoggingObserver {
}
}

fn do_handshake(mut channel: multiplexer::StdChannel) {
fn do_handshake(mut channel: multiplexer::StdChannelBuffer) {
let versions = handshake::n2c::VersionTable::v1_and_above(MAINNET_MAGIC);
let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap();
}

fn do_localstate_query(mut channel: multiplexer::StdChannel) {
fn do_localstate_query(mut channel: multiplexer::StdChannelBuffer) {
let agent = run_agent(
localstate::OneShotClient::<localstate::queries::QueryV10>::initial(
None,
Expand All @@ -60,7 +60,7 @@ fn do_localstate_query(mut channel: multiplexer::StdChannel) {
log::info!("state query result: {:?}", agent);
}

fn do_chainsync(mut channel: multiplexer::StdChannel) {
fn do_chainsync(mut channel: multiplexer::StdChannelBuffer) {
let known_points = vec![Point::Specific(
43847831u64,
hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(),
Expand Down Expand Up @@ -89,9 +89,9 @@ fn main() {
// setup the multiplexer by specifying the bearer and the IDs of the
// miniprotocols to use
let mut plexer = multiplexer::StdPlexer::new(bearer);
let channel0 = plexer.use_channel(0);
let channel7 = plexer.use_channel(7);
let channel5 = plexer.use_channel(5);
let channel0 = plexer.use_channel(0).into();
let channel7 = plexer.use_channel(7).into();
let channel5 = plexer.use_channel(5).into();

plexer.muxer.spawn();
plexer.demuxer.spawn();
Expand Down
14 changes: 7 additions & 7 deletions examples/n2n-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pallas::network::{
miniprotocols::{blockfetch, chainsync, handshake, run_agent, Point, MAINNET_MAGIC},
multiplexer::{bearers::Bearer, StdChannel, StdPlexer},
multiplexer::{agents::ChannelBuffer, bearers::Bearer, StdChannel, StdPlexer},
};

#[derive(Debug)]
Expand Down Expand Up @@ -50,12 +50,12 @@ impl chainsync::Observer<chainsync::HeaderContent> for LoggingObserver {
}
}

fn do_handshake(mut channel: StdChannel) {
fn do_handshake(mut channel: ChannelBuffer<StdChannel>) {
let versions = handshake::n2n::VersionTable::v4_and_above(MAINNET_MAGIC);
let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap();
}

fn do_blockfetch(mut channel: StdChannel) {
fn do_blockfetch(mut channel: ChannelBuffer<StdChannel>) {
let range = (
Point::Specific(
43847831,
Expand All @@ -77,7 +77,7 @@ fn do_blockfetch(mut channel: StdChannel) {
println!("{:?}", agent);
}

fn do_chainsync(mut channel: StdChannel) {
fn do_chainsync(mut channel: ChannelBuffer<StdChannel>) {
let known_points = vec![Point::Specific(
43847831u64,
hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(),
Expand Down Expand Up @@ -106,9 +106,9 @@ fn main() {
// setup the multiplexer by specifying the bearer and the IDs of the
// miniprotocols to use
let mut plexer = StdPlexer::new(bearer);
let channel0 = plexer.use_channel(0);
let channel3 = plexer.use_channel(3);
let channel2 = plexer.use_channel(2);
let channel0 = plexer.use_channel(0).into();
let channel3 = plexer.use_channel(3).into();
let channel2 = plexer.use_channel(2).into();

plexer.muxer.spawn();
plexer.demuxer.spawn();
Expand Down
14 changes: 6 additions & 8 deletions pallas-miniprotocols/src/machines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ pub trait Agent: Sized {
fn apply_inbound(self, msg: Self::Message) -> Transition<Self>;
}

pub struct Runner<'c, A, C>
pub struct Runner<A, C>
where
A: Agent,
C: Channel,
{
agent: Cell<Option<A>>,
buffer: ChannelBuffer<'c, C>,
buffer: ChannelBuffer<C>,
}

impl<'c, A, C> Runner<'c, A, C>
impl<A, C> Runner<A, C>
where
A: Agent,
A::Message: Fragment + std::fmt::Debug,
C: Channel,
{
pub fn new(agent: A, channel: &'c mut C) -> Self {
pub fn new(agent: A, channel: C) -> Self {
Self {
agent: Cell::new(Some(agent)),
buffer: ChannelBuffer::new(channel),
Expand Down Expand Up @@ -119,18 +119,16 @@ where
}
}

pub fn run_agent<A, C>(agent: A, channel: &mut C) -> Transition<A>
pub fn run_agent<A, C>(agent: A, buffer: &mut ChannelBuffer<C>) -> Transition<A>
where
A: Agent,
A::Message: Fragment + std::fmt::Debug,
C: Channel,
{
let mut buffer = ChannelBuffer::new(channel);

let mut agent = agent.apply_start()?;

while !agent.is_done() {
agent = run_agent_step(agent, &mut buffer)?;
agent = run_agent_step(agent, buffer)?;
}

Ok(agent)
Expand Down
18 changes: 14 additions & 4 deletions pallas-multiplexer/src/agents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ where
}

/// A channel abstraction to hide the complexity of partial payloads
pub struct ChannelBuffer<'c, C: Channel> {
channel: &'c mut C,
pub struct ChannelBuffer<C: Channel> {
channel: C,
temp: Vec<u8>,
}

impl<'c, C: Channel> ChannelBuffer<'c, C> {
pub fn new(channel: &'c mut C) -> Self {
impl<C: Channel> ChannelBuffer<C> {
pub fn new(channel: C) -> Self {
Self {
channel,
temp: Vec::new(),
Expand Down Expand Up @@ -105,4 +105,14 @@ impl<'c, C: Channel> ChannelBuffer<'c, C> {
}
}
}

pub fn unwrap(self) -> C {
self.channel
}
}

impl<C: Channel> From<C> for ChannelBuffer<C> {
fn from(channel: C) -> Self {
ChannelBuffer::new(channel)
}
}
7 changes: 6 additions & 1 deletion pallas-multiplexer/src/std.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{agents, demux, mux, Payload};
use crate::{
agents::{self, ChannelBuffer},
demux, mux, Payload,
};

use std::{
sync::{
Expand Down Expand Up @@ -103,6 +106,8 @@ impl demux::Demuxer<StdEgress> {

pub type StdChannel = (Sender<Payload>, Receiver<Payload>);

pub type StdChannelBuffer = ChannelBuffer<StdChannel>;

impl agents::Channel for StdChannel {
fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> {
match self.0.send(payload) {
Expand Down
8 changes: 4 additions & 4 deletions pallas-multiplexer/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ fn multiple_messages_in_same_payload() {
minicbor::encode(in_part1, &mut input).unwrap();
minicbor::encode(in_part2, &mut input).unwrap();

let mut channel = std::sync::mpsc::channel();
let channel = std::sync::mpsc::channel();
channel.0.send(input).unwrap();

let mut buf = ChannelBuffer::new(&mut channel);
let mut buf = ChannelBuffer::new(channel);

let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();
let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();
Expand All @@ -90,14 +90,14 @@ fn fragmented_message_in_multiple_payloads() {
let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
minicbor::encode(msg, &mut input).unwrap();

let mut channel = std::sync::mpsc::channel();
let channel = std::sync::mpsc::channel();

while !input.is_empty() {
let chunk = Vec::from(input.drain(0..2).as_slice());
channel.0.send(chunk).unwrap();
}

let mut buf = ChannelBuffer::new(&mut channel);
let mut buf = ChannelBuffer::new(channel);

let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap();

Expand Down