Skip to content

Commit

Permalink
refactor(autonat): use quick-protobuf-codec
Browse files Browse the repository at this point in the history
Resolves: #4489.
Resolves: #2500.

Pull-Request: #4787.
  • Loading branch information
thomaseizinger authored Nov 2, 2023
1 parent ac28488 commit d05d836
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 55 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions protocols/autonat/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ libp2p-identity = { workspace = true }
log = "0.4"
rand = "0.8"
quick-protobuf = "0.8"
quick-protobuf-codec = { workspace = true }
asynchronous-codec = "0.6.2"

[dev-dependencies]
async-std = { version = "1.10", features = ["attributes"] }
Expand Down
100 changes: 45 additions & 55 deletions protocols/autonat/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

use crate::proto;
use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use libp2p_core::{upgrade, Multiaddr};
use asynchronous_codec::{FramedRead, FramedWrite};
use futures::io::{AsyncRead, AsyncWrite};
use futures::{SinkExt, StreamExt};
use libp2p_core::Multiaddr;
use libp2p_identity::PeerId;
use libp2p_request_response::{self as request_response};
use libp2p_swarm::StreamProtocol;
use quick_protobuf::{BytesReader, Writer};
use std::{convert::TryFrom, io};

/// The protocol name used for negotiating with multistream-select.
Expand All @@ -44,8 +45,12 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncRead + Send + Unpin,
{
let bytes = upgrade::read_length_prefixed(io, 1024).await?;
let request = DialRequest::from_bytes(&bytes)?;
let message = FramedRead::new(io, codec())
.next()
.await
.ok_or(io::ErrorKind::UnexpectedEof)??;
let request = DialRequest::from_proto(message)?;

Ok(request)
}

Expand All @@ -57,8 +62,12 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncRead + Send + Unpin,
{
let bytes = upgrade::read_length_prefixed(io, 1024).await?;
let response = DialResponse::from_bytes(&bytes)?;
let message = FramedRead::new(io, codec())
.next()
.await
.ok_or(io::ErrorKind::UnexpectedEof)??;
let response = DialResponse::from_proto(message)?;

Ok(response)
}

Expand All @@ -71,8 +80,11 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncWrite + Send + Unpin,
{
upgrade::write_length_prefixed(io, data.into_bytes()).await?;
io.close().await
let mut framed = FramedWrite::new(io, codec());
framed.send(data.into_proto()).await?;
framed.close().await?;

Ok(())
}

async fn write_response<T>(
Expand All @@ -84,24 +96,26 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncWrite + Send + Unpin,
{
upgrade::write_length_prefixed(io, data.into_bytes()).await?;
io.close().await
let mut framed = FramedWrite::new(io, codec());
framed.send(data.into_proto()).await?;
framed.close().await?;

Ok(())
}
}

fn codec() -> quick_protobuf_codec::Codec<proto::Message> {
quick_protobuf_codec::Codec::<proto::Message>::new(1024)
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DialRequest {
pub peer_id: PeerId,
pub addresses: Vec<Multiaddr>,
}

impl DialRequest {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, io::Error> {
use quick_protobuf::MessageRead;

let mut reader = BytesReader::from_bytes(bytes);
let msg = proto::Message::from_reader(&mut reader, bytes)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
if msg.type_pb != Some(proto::MessageType::DIAL) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
}
Expand Down Expand Up @@ -143,17 +157,15 @@ impl DialRequest {
})
}

pub fn into_bytes(self) -> Vec<u8> {
use quick_protobuf::MessageWrite;

pub fn into_proto(self) -> proto::Message {
let peer_id = self.peer_id.to_bytes();
let addrs = self
.addresses
.into_iter()
.map(|addr| addr.to_vec())
.collect();

let msg = proto::Message {
proto::Message {
type_pb: Some(proto::MessageType::DIAL),
dial: Some(proto::Dial {
peer: Some(proto::PeerInfo {
Expand All @@ -162,12 +174,7 @@ impl DialRequest {
}),
}),
dialResponse: None,
};

let mut buf = Vec::with_capacity(msg.get_size());
let mut writer = Writer::new(&mut buf);
msg.write_message(&mut writer).expect("Encoding to succeed");
buf
}
}
}

Expand Down Expand Up @@ -217,12 +224,7 @@ pub struct DialResponse {
}

impl DialResponse {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, io::Error> {
use quick_protobuf::MessageRead;

let mut reader = BytesReader::from_bytes(bytes);
let msg = proto::Message::from_reader(&mut reader, bytes)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
if msg.type_pb != Some(proto::MessageType::DIAL_RESPONSE) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
}
Expand Down Expand Up @@ -258,9 +260,7 @@ impl DialResponse {
})
}

pub fn into_bytes(self) -> Vec<u8> {
use quick_protobuf::MessageWrite;

pub fn into_proto(self) -> proto::Message {
let dial_response = match self.result {
Ok(addr) => proto::DialResponse {
status: Some(proto::ResponseStatus::OK),
Expand All @@ -274,23 +274,17 @@ impl DialResponse {
},
};

let msg = proto::Message {
proto::Message {
type_pb: Some(proto::MessageType::DIAL_RESPONSE),
dial: None,
dialResponse: Some(dial_response),
};

let mut buf = Vec::with_capacity(msg.get_size());
let mut writer = Writer::new(&mut buf);
msg.write_message(&mut writer).expect("Encoding to succeed");
buf
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use quick_protobuf::MessageWrite;

#[test]
fn test_request_encode_decode() {
Expand All @@ -301,8 +295,8 @@ mod tests {
"/ip4/192.168.1.42/tcp/30333".parse().unwrap(),
],
};
let bytes = request.clone().into_bytes();
let request2 = DialRequest::from_bytes(&bytes).unwrap();
let proto = request.clone().into_proto();
let request2 = DialRequest::from_proto(proto).unwrap();
assert_eq!(request, request2);
}

Expand All @@ -312,8 +306,8 @@ mod tests {
result: Ok("/ip4/8.8.8.8/tcp/30333".parse().unwrap()),
status_text: None,
};
let bytes = response.clone().into_bytes();
let response2 = DialResponse::from_bytes(&bytes).unwrap();
let proto = response.clone().into_proto();
let response2 = DialResponse::from_proto(proto).unwrap();
assert_eq!(response, response2);
}

Expand All @@ -323,8 +317,8 @@ mod tests {
result: Err(ResponseError::DialError),
status_text: Some("dial failed".to_string()),
};
let bytes = response.clone().into_bytes();
let response2 = DialResponse::from_bytes(&bytes).unwrap();
let proto = response.clone().into_proto();
let response2 = DialResponse::from_proto(proto).unwrap();
assert_eq!(response, response2);
}

Expand All @@ -350,11 +344,7 @@ mod tests {
dialResponse: None,
};

let mut bytes = Vec::with_capacity(msg.get_size());
let mut writer = Writer::new(&mut bytes);
msg.write_message(&mut writer).expect("Encoding to succeed");

let request = DialRequest::from_bytes(&bytes).expect("not to fail");
let request = DialRequest::from_proto(msg).expect("not to fail");

assert_eq!(request.addresses, vec![valid_multiaddr])
}
Expand Down

0 comments on commit d05d836

Please sign in to comment.