diff --git a/src/async/win/device.rs b/src/async/win/device.rs index 8b67037d..e41c82bf 100644 --- a/src/async/win/device.rs +++ b/src/async/win/device.rs @@ -58,8 +58,8 @@ impl AsyncRead for AsyncDevice { ) -> Poll> { let rbuf = buf.initialize_unfilled(); match Pin::new(&mut self.inner).poll_read(cx, rbuf) { - Poll::Ready(Ok(n)) => { - buf.advance(n); + Poll::Ready(Ok(size)) => { + buf.advance(size); Poll::Ready(Ok(())) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), @@ -126,10 +126,7 @@ impl AsyncRead for AsyncQueue { ) -> Poll> { let rbuf = buf.initialize_unfilled(); match Pin::new(&mut self.inner).poll_read(cx, rbuf) { - Poll::Ready(Ok(n)) => { - buf.advance(n); - Poll::Ready(Ok(())) - } + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, } diff --git a/src/platform/windows/device.rs b/src/platform/windows/device.rs index 53bee9e4..1e5b0cbe 100644 --- a/src/platform/windows/device.rs +++ b/src/platform/windows/device.rs @@ -14,11 +14,11 @@ use std::io::{self, Read, Write}; use std::net::{IpAddr, Ipv4Addr}; +#[cfg(feature = "async")] use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; +use std::sync::Arc; +#[cfg(feature = "async")] use std::thread; -use std::vec::Vec; use wintun::Session; @@ -32,19 +32,44 @@ pub struct Device { mtu: usize, } +#[cfg(feature = "async")] +fn create_queue(session: Session) -> Queue { + let session = Arc::new(session); + let (receiver_tx, receiver_rx) = tokio::sync::mpsc::unbounded_channel::>(); + let session_reader = session.clone(); + let task = thread::spawn(move || { + while let Ok(packet) = session_reader.receive_blocking() { + let bytes = packet.bytes().to_vec(); + //dbg!(&bytes); + receiver_tx.send(bytes).unwrap(); + } + }); + Queue { + session, + receiver: receiver_rx, + _task: task, + } +} + +#[cfg(not(feature = "async"))] +fn create_queue(session: Session) -> Queue { + Queue { + session: Arc::new(session), + } +} + impl Device { /// Create a new `Device` for the given `Configuration`. pub fn new(config: &Configuration) -> Result { let wintun = unsafe { wintun::load()? }; let tun_name = config.name.as_deref().unwrap_or("wintun"); - let guid = Some(9099482345783245345345_u128); let adapter = match wintun::Adapter::open(&wintun, tun_name) { Ok(a) => a, - Err(_) => wintun::Adapter::create(&wintun, tun_name, tun_name, guid)?, + Err(_) => wintun::Adapter::create(&wintun, tun_name, tun_name, None)?, }; - let address = config.address.unwrap_or(Ipv4Addr::new(10, 1, 0, 2)); - let mask = config.netmask.unwrap_or(Ipv4Addr::new(255, 255, 255, 0)); + let address = config.address.ok_or(Error::InvalidConfig)?; + let mask = config.netmask.ok_or(Error::InvalidConfig)?; let gateway = config.destination.map(IpAddr::from); adapter.set_network_addresses_tuple(IpAddr::V4(address), IpAddr::V4(mask), gateway)?; let mtu = config.mtu.unwrap_or(1500) as usize; @@ -52,10 +77,7 @@ impl Device { let session = adapter.start_session(wintun::MAX_RING_CAPACITY)?; let mut device = Device { - queue: Queue { - session: Arc::new(session), - cached: Arc::new(Mutex::new(Vec::with_capacity(mtu))), - }, + queue: create_queue(session), mtu, }; @@ -64,24 +86,45 @@ impl Device { Ok(device) } +} +#[cfg(feature = "async")] +impl Device { pub fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, buf: &mut [u8], - ) -> Poll> { + ) -> std::task::Poll> { Pin::new(&mut self.queue).poll_read(cx, buf) } + + pub fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + Pin::new(&mut self.queue).poll_write(cx, buf) + } + + pub fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + pub fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } } impl Read for Device { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.queue.read(buf) } - - fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { - self.queue.read_vectored(bufs) - } } impl Write for Device { @@ -89,12 +132,8 @@ impl Write for Device { self.queue.write(buf) } - fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { - self.queue.write_vectored(bufs) - } - fn flush(&mut self) -> io::Result<()> { - self.queue.flush() + Ok(()) } } @@ -183,7 +222,6 @@ impl D for Device { fn set_mtu(&mut self, value: i32) -> Result<()> { self.mtu = value as usize; - self.queue.cached = Arc::new(Mutex::new(Vec::with_capacity(self.mtu))); Ok(()) } @@ -194,99 +232,85 @@ impl D for Device { pub struct Queue { session: Arc, - cached: Arc>>, + #[cfg(feature = "async")] + receiver: tokio::sync::mpsc::UnboundedReceiver>, + #[cfg(feature = "async")] + _task: thread::JoinHandle<()>, } +#[cfg(feature = "async")] impl Queue { pub fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: &mut [u8], - ) -> Poll> { - { - let mut cached = self - .cached - .lock() - .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; - if cached.len() > 0 { - let res = match io::copy(&mut cached.as_slice(), &mut buf) { - Ok(n) => Poll::Ready(Ok(n as usize)), - Err(e) => Poll::Ready(Err(e)), - }; - cached.clear(); - return res; - } - } - let reader_session = self.session.clone(); - match reader_session.try_receive() { - Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), - Ok(Some(packet)) => match io::copy(&mut packet.bytes(), &mut buf) { - Ok(n) => Poll::Ready(Ok(n as usize)), - Err(e) => Poll::Ready(Err(e)), - }, - Ok(None) => { - let waker = cx.waker().clone(); - let cached = self.cached.clone(); - thread::spawn(move || { - match reader_session.receive_blocking() { - Ok(packet) => { - if let Ok(mut cached) = cached.lock() { - cached.extend_from_slice(packet.bytes()); - } else { - log::error!("cached lock error in wintun reciever thread, packet will be dropped"); - } - } - Err(e) => log::error!("receive_blocking error: {:?}", e), - } - waker.wake() + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> std::task::Poll> { + match std::task::ready!(self.receiver.poll_recv(cx)) { + Some(bytes) => { + //dbg!(buf.len(), bytes.len()); + bytes.iter().enumerate().for_each(|(index, value)| { + buf[index] = *value; }); - Poll::Pending + std::task::Poll::Ready(Ok(bytes.len())) } + None => std::task::Poll::Ready(Ok(0)), } } - #[allow(dead_code)] - fn try_read(&mut self, mut buf: &mut [u8]) -> io::Result { - let reader_session = self.session.clone(); - match reader_session.try_receive() { - Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), - Ok(op) => match op { - None => Ok(0), - Some(packet) => match io::copy(&mut packet.bytes(), &mut buf) { - Ok(s) => Ok(s as usize), - Err(e) => Err(e), - }, - }, - } + pub fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let mut write_pack = self.session.allocate_send_packet(buf.len() as u16)?; + write_pack.bytes_mut().copy_from_slice(buf.as_ref()); + self.session.send_packet(write_pack); + std::task::Poll::Ready(Ok(buf.len())) + } + + pub fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + pub fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } } impl Read for Queue { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - let reader_session = self.session.clone(); - match reader_session.receive_blocking() { - Ok(pkt) => match io::copy(&mut pkt.bytes(), &mut buf) { - Ok(n) => Ok(n as usize), - Err(e) => Err(e), - }, - Err(e) => Err(io::Error::new(io::ErrorKind::ConnectionAborted, e)), + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.session.receive_blocking() { + Ok(pkt) => { + let bytes = pkt.bytes(); + let len = bytes.len(); + if len <= buf.len() { + buf[..len].clone_from_slice(bytes); + Ok(len) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "no large enough storage to save data", + )) + } + } + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::NotConnected, "")), } } } impl Write for Queue { - fn write(&mut self, mut buf: &[u8]) -> io::Result { - let size = buf.len(); - match self.session.allocate_send_packet(size as u16) { - Err(e) => Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), - Ok(mut packet) => match io::copy(&mut buf, &mut packet.bytes_mut()) { - Ok(s) => { - self.session.send_packet(packet); - Ok(s as usize) - } - Err(e) => Err(e), - }, - } + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = buf.len(); + let mut write_pack = self.session.allocate_send_packet(len as u16)?; + write_pack.bytes_mut().copy_from_slice(buf.as_ref()); + self.session.send_packet(write_pack); + Ok(len) } fn flush(&mut self) -> io::Result<()> { diff --git a/wintun.dll b/wintun.dll new file mode 100644 index 00000000..aee04e77 Binary files /dev/null and b/wintun.dll differ