Skip to content

Commit

Permalink
Move SocketSniffer to separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
Serock3 committed Dec 20, 2024
1 parent e02fe42 commit 0093cb2
Showing 1 changed file with 58 additions and 52 deletions.
110 changes: 58 additions & 52 deletions talpid-tunnel-config-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::net::SocketAddr;
#[cfg(not(target_os = "ios"))]
use std::net::{IpAddr, Ipv4Addr};
use talpid_types::net::wireguard::{PresharedKey, PublicKey};
use tokio::io::{AsyncRead, AsyncWrite};
use tonic::transport::Channel;
#[cfg(not(target_os = "ios"))]
use tonic::transport::Endpoint;
Expand Down Expand Up @@ -279,84 +278,91 @@ fn xor_assign(dst: &mut [u8; 32], src: &[u8; 32]) {
/// value has been speficically lowered, to avoid MTU issues. See the `socket` module.
#[cfg(not(target_os = "ios"))]
async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService, Error> {
use hyper_util::rt::tokio::TokioIo;

let endpoint = Endpoint::from_static("tcp://0.0.0.0:0");
let addr = SocketAddr::new(IpAddr::V4(ip), CONFIG_SERVICE_PORT);

let connection = endpoint
.connect_with_connector(service_fn(move |_| async move {
let sock = socket::TcpSocket::new()?;
let stream = sock.connect(addr).await?;
let sniffer = SocketSniffer {
let sniffer = socket_sniffer::SocketSniffer {
s: stream,
rx_bytes: 0,
tx_bytes: 0,
start_time: std::time::Instant::now(),
};
Ok::<_, std::io::Error>(hyper_util::rt::tokio::TokioIo::new(sniffer))
Ok::<_, std::io::Error>(TokioIo::new(sniffer))
}))
.await
.map_err(Error::GrpcConnectError)?;

Ok(RelayConfigService::new(connection))
}

struct SocketSniffer<S> {
s: S,
rx_bytes: usize,
tx_bytes: usize,
start_time: std::time::Instant,
}

impl<S> Drop for SocketSniffer<S> {
fn drop(&mut self) {
let duration = self.start_time.elapsed();
log::debug!(
"Tunnel config client connection ended. RX: {} bytes, TX: {} bytes, duration: {} s",
self.rx_bytes,
self.tx_bytes,
duration.as_secs()
);
mod socket_sniffer {
pub struct SocketSniffer<S> {
pub s: S,
pub rx_bytes: usize,
pub tx_bytes: usize,
pub start_time: std::time::Instant,
}
}
use std::{
io,
pin::Pin,
task::{Context, Poll},
};

use tokio::io::AsyncWrite;

impl<S: AsyncRead + AsyncWrite + std::marker::Unpin> AsyncRead for SocketSniffer<S> {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let bytes = std::task::ready!(std::pin::Pin::new(&mut self.s).poll_read(cx, buf));
if bytes.is_ok() {
self.rx_bytes += buf.filled().len();
use tokio::io::{AsyncRead, ReadBuf};

impl<S> Drop for SocketSniffer<S> {
fn drop(&mut self) {
let duration = self.start_time.elapsed();
log::debug!(
"Tunnel config client connection ended. RX: {} bytes, TX: {} bytes, duration: {} s",
self.rx_bytes,
self.tx_bytes,
duration.as_secs()
);
}
std::task::Poll::Ready(bytes)
}
}

impl<S: AsyncRead + AsyncWrite + std::marker::Unpin> AsyncWrite for SocketSniffer<S> {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
let bytes = std::task::ready!(std::pin::Pin::new(&mut self.s).poll_write(cx, buf));
if bytes.is_ok() {
self.tx_bytes += buf.len();
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SocketSniffer<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let bytes = std::task::ready!(Pin::new(&mut self.s).poll_read(cx, buf));
if bytes.is_ok() {
self.rx_bytes += buf.filled().len();
}
Poll::Ready(bytes)
}
std::task::Poll::Ready(bytes)
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.s).poll_flush(cx)
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SocketSniffer<S> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let bytes = std::task::ready!(Pin::new(&mut self.s).poll_write(cx, buf));
if bytes.is_ok() {
self.tx_bytes += buf.len();
}
Poll::Ready(bytes)
}

fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.s).poll_shutdown(cx)
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.s).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.s).poll_shutdown(cx)
}
}
}

0 comments on commit 0093cb2

Please sign in to comment.