Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(network): refactor local state #326

Merged
merged 18 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions examples/n2c-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
use pallas::network::{
facades::NodeClient,
miniprotocols::{chainsync, localstate, Point, MAINNET_MAGIC},
miniprotocols::{chainsync, localstate::queries_v16, Point, PRE_PRODUCTION_MAGIC},
};
use tracing::info;

async fn do_localstate_query(client: &mut NodeClient) {
client.statequery().acquire(None).await.unwrap();
let client = client.statequery();

let result = client
.statequery()
.query(localstate::queries::Request::GetSystemStart)
client.acquire(None).await.unwrap();

let result = queries_v16::get_chain_point(client).await.unwrap();
info!("result: {:?}", result);

let result = queries_v16::get_system_start(client).await.unwrap();
info!("result: {:?}", result);

let era = queries_v16::get_current_era(client).await.unwrap();
info!("result: {:?}", era);

let result = queries_v16::get_block_epoch_number(client, era)
.await
.unwrap();

info!("system start result: {:?}", result);
info!("result: {:?}", result);

client.send_release().await.unwrap();
}

async fn do_chainsync(client: &mut NodeClient) {
Expand Down Expand Up @@ -43,6 +54,10 @@ async fn do_chainsync(client: &mut NodeClient) {
}
}

// change the following to match the Cardano node socket in your local
// environment
const SOCKET_PATH: &str = "/tmp/node.socket";

#[cfg(target_family = "unix")]
#[tokio::main]
async fn main() {
Expand All @@ -55,15 +70,7 @@ async fn main() {

// we connect to the unix socket of the local node. Make sure you have the right
// path for your environment
let socket_path = "/tmp/node.socket";

// we connect to the unix socket of the local node and perform a handshake query
let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC)
.await
.unwrap();
info!("handshake query result: {:?}", version_table);

let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC)
let mut client = NodeClient::connect(SOCKET_PATH, PRE_PRODUCTION_MAGIC)
.await
.unwrap();

Expand All @@ -75,7 +82,6 @@ async fn main() {
}

#[cfg(not(target_family = "unix"))]

fn main() {
panic!("can't use n2c unix socket on non-unix systems");
}
2 changes: 1 addition & 1 deletion pallas-network/src/facades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl NodeServer {
plexer_handle,
version: ver,
chainsync: server_cs,
statequery: server_sq
statequery: server_sq,
})
} else {
plexer_handle.abort();
Expand Down
80 changes: 44 additions & 36 deletions pallas-network/src/miniprotocols/localstate/client.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
use pallas_codec::utils::AnyCbor;
use std::fmt::Debug;

use pallas_codec::Fragment;

use std::marker::PhantomData;
use thiserror::*;

use super::{AcquireFailure, Message, Query, State};
use super::{AcquireFailure, Message, State};
use crate::miniprotocols::Point;
use crate::multiplexer;

#[derive(Error, Debug)]
pub enum ClientError {
#[error("attempted to receive message while agency is ours")]
AgencyIsOurs,

#[error("attempted to send message while agency is theirs")]
AgencyIsTheirs,

#[error("inbound message is not valid for current state")]
InvalidInbound,

#[error("outbound message is not valid for current state")]
InvalidOutbound,

#[error("failure acquiring point, not found")]
AcquirePointNotFound,

#[error("failure acquiring point, too old")]
AcquirePointTooOld,

#[error("failure decoding CBOR data")]
InvalidCbor(pallas_codec::minicbor::decode::Error),

#[error("error while sending or receiving data through the channel")]
Plexer(multiplexer::Error),
}
Expand All @@ -36,22 +42,11 @@ impl From<AcquireFailure> for ClientError {
}
}

pub struct GenericClient<Q>(State, multiplexer::ChannelBuffer, PhantomData<Q>)
where
Q: Query,
Message<Q>: Fragment;
pub struct GenericClient(State, multiplexer::ChannelBuffer);

impl<Q> GenericClient<Q>
where
Q: Query,
Message<Q>: Fragment,
{
impl GenericClient {
pub fn new(channel: multiplexer::AgentChannel) -> Self {
Self(
State::Idle,
multiplexer::ChannelBuffer::new(channel),
PhantomData {},
)
Self(State::Idle, multiplexer::ChannelBuffer::new(channel))
}

pub fn state(&self) -> &State {
Expand Down Expand Up @@ -87,7 +82,7 @@ where
}
}

fn assert_outbound_state(&self, msg: &Message<Q>) -> Result<(), ClientError> {
fn assert_outbound_state(&self, msg: &Message) -> Result<(), ClientError> {
match (&self.0, msg) {
(State::Idle, Message::Acquire(_)) => Ok(()),
(State::Idle, Message::Done) => Ok(()),
Expand All @@ -98,7 +93,7 @@ where
}
}

fn assert_inbound_state(&self, msg: &Message<Q>) -> Result<(), ClientError> {
fn assert_inbound_state(&self, msg: &Message) -> Result<(), ClientError> {
match (&self.0, msg) {
(State::Acquiring, Message::Acquired) => Ok(()),
(State::Acquiring, Message::Failure(_)) => Ok(()),
Expand All @@ -107,15 +102,18 @@ where
}
}

pub async fn send_message(&mut self, msg: &Message<Q>) -> Result<(), ClientError> {
pub async fn send_message(&mut self, msg: &Message) -> Result<(), ClientError> {
self.assert_agency_is_ours()?;
self.assert_outbound_state(msg)?;
self.1.send_msg_chunks(msg).await.map_err(ClientError::Plexer)?;
self.1
.send_msg_chunks(msg)
.await
.map_err(ClientError::Plexer)?;

Ok(())
}

pub async fn recv_message(&mut self) -> Result<Message<Q>, ClientError> {
pub async fn recv_message(&mut self) -> Result<Message, ClientError> {
self.assert_agency_is_theirs()?;
let msg = self.1.recv_full_msg().await.map_err(ClientError::Plexer)?;
self.assert_inbound_state(&msg)?;
Expand All @@ -124,31 +122,31 @@ where
}

pub async fn send_acquire(&mut self, point: Option<Point>) -> Result<(), ClientError> {
let msg = Message::<Q>::Acquire(point);
let msg = Message::Acquire(point);
self.send_message(&msg).await?;
self.0 = State::Acquiring;

Ok(())
}

pub async fn send_reacquire(&mut self, point: Option<Point>) -> Result<(), ClientError> {
let msg = Message::<Q>::ReAcquire(point);
let msg = Message::ReAcquire(point);
self.send_message(&msg).await?;
self.0 = State::Acquiring;

Ok(())
}

pub async fn send_release(&mut self) -> Result<(), ClientError> {
let msg = Message::<Q>::Release;
let msg = Message::Release;
self.send_message(&msg).await?;
self.0 = State::Idle;

Ok(())
}

pub async fn send_done(&mut self) -> Result<(), ClientError> {
let msg = Message::<Q>::Done;
let msg = Message::Done;
self.send_message(&msg).await?;
self.0 = State::Done;

Expand All @@ -174,28 +172,38 @@ where
self.recv_while_acquiring().await
}

pub async fn send_query(&mut self, request: Q::Request) -> Result<(), ClientError> {
let msg = Message::<Q>::Query(request);
pub async fn send_query(&mut self, request: AnyCbor) -> Result<Message, ClientError> {
let msg = Message::Query(request);
self.send_message(&msg).await?;
self.0 = State::Querying;

Ok(())
Ok(msg)
}

pub async fn recv_while_querying(&mut self) -> Result<Q::Response, ClientError> {
pub async fn recv_while_querying(&mut self) -> Result<AnyCbor, ClientError> {
match self.recv_message().await? {
Message::Result(x) => {
Message::Result(result) => {
self.0 = State::Acquired;
Ok(x)
Ok(result)
}
_ => Err(ClientError::InvalidInbound),
}
}

pub async fn query(&mut self, request: Q::Request) -> Result<Q::Response, ClientError> {
pub async fn query_any(&mut self, request: AnyCbor) -> Result<AnyCbor, ClientError> {
self.send_query(request).await?;
self.recv_while_querying().await
}

pub async fn query<Q, R>(&mut self, request: Q) -> Result<R, ClientError>
where
Q: pallas_codec::minicbor::Encode<()>,
for<'b> R: pallas_codec::minicbor::Decode<'b, ()>,
{
let request = AnyCbor::from_encode(request);
let response = self.query_any(request).await?;
response.into_decode().map_err(ClientError::InvalidCbor)
}
}

pub type Client = GenericClient<super::queries::QueryV16>;
pub type Client = GenericClient;
16 changes: 3 additions & 13 deletions pallas-network/src/miniprotocols/localstate/codec.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pallas_codec::minicbor::{decode, encode, Decode, Encode, Encoder};

use super::{AcquireFailure, Message, Query};
use super::{AcquireFailure, Message};

impl Encode<()> for AcquireFailure {
fn encode<W: encode::Write>(
Expand Down Expand Up @@ -36,12 +36,7 @@ impl<'b> Decode<'b, ()> for AcquireFailure {
}
}

impl<Q> Encode<()> for Message<Q>
where
Q: Query,
Q::Request: Encode<()>,
Q::Response: Encode<()>,
{
impl Encode<()> for Message {
fn encode<W: encode::Write>(
&self,
e: &mut Encoder<W>,
Expand Down Expand Up @@ -97,12 +92,7 @@ where
}
}

impl<'b, Q> Decode<'b, ()> for Message<Q>
where
Q: Query,
Q::Request: Decode<'b, ()>,
Q::Response: Decode<'b, ()>,
{
impl<'b> Decode<'b, ()> for Message {
fn decode(
d: &mut pallas_codec::minicbor::Decoder<'b>,
_ctx: &mut (),
Expand Down
3 changes: 2 additions & 1 deletion pallas-network/src/miniprotocols/localstate/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
mod client;
mod codec;
mod protocol;
pub mod queries;
mod server;

pub mod queries_v16;

pub use client::*;
pub use codec::*;
pub use protocol::*;
Expand Down
13 changes: 5 additions & 8 deletions pallas-network/src/miniprotocols/localstate/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Debug;

use pallas_codec::utils::AnyCbor;

use crate::miniprotocols::Point;

#[derive(Debug, PartialEq, Eq, Clone)]
Expand All @@ -17,18 +19,13 @@ pub enum AcquireFailure {
PointNotOnChain,
}

pub trait Query: Debug {
type Request: Clone + Debug;
type Response: Clone + Debug;
}

#[derive(Debug)]
pub enum Message<Q: Query> {
pub enum Message {
Acquire(Option<Point>),
Failure(AcquireFailure),
Acquired,
Query(Q::Request),
Result(Q::Response),
Query(AnyCbor),
Result(AnyCbor),
ReAcquire(Option<Point>),
Release,
Done,
Expand Down
Loading