Skip to content

Commit

Permalink
net: add unix stream & listener (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-chuang authored Mar 2, 2022
1 parent 523dae0 commit 3176149
Show file tree
Hide file tree
Showing 15 changed files with 283 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ scoped-tls = "1.0.0"
slab = "0.4.2"
libc = "0.2.80"
io-uring = { version = "0.5.0", features = [ "unstable" ] }
os_socketaddr = "0.2.0"
socket2 = { version = "0.4.4", features = [ "all"] }
bytes = { version = "1.0", optional = true }

[dev-dependencies]
Expand Down
32 changes: 32 additions & 0 deletions examples/unix_listener.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::env;

use tokio_uring::net::UnixListener;

fn main() {
let args: Vec<_> = env::args().collect();

if args.len() <= 1 {
panic!("no addr specified");
}

let socket_addr: String = args[1].clone();

tokio_uring::start(async {
let listener = UnixListener::bind(&socket_addr).unwrap();

loop {
let stream = listener.accept().await.unwrap();
let socket_addr = socket_addr.clone();
tokio_uring::spawn(async move {
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
println!("written to {}: {}", &socket_addr, result.unwrap());

let (result, buf) = stream.read(buf).await;
let read = result.unwrap();
println!("read from {}: {:?}", &socket_addr, &buf[..read]);
});
}
});
}
25 changes: 25 additions & 0 deletions examples/unix_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::env;

use tokio_uring::net::UnixStream;

fn main() {
let args: Vec<_> = env::args().collect();

if args.len() <= 1 {
panic!("no addr specified");
}

let socket_addr: &String = &args[1];

tokio_uring::start(async {
let stream = UnixStream::connect(socket_addr).await.unwrap();
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
println!("written: {}", result.unwrap());

let (result, buf) = stream.read(buf).await;
let read = result.unwrap();
println!("read: {:?}", &buf[..read]);
});
}
16 changes: 7 additions & 9 deletions src/driver/connect.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
use crate::driver::{Op, SharedFd};
use os_socketaddr::OsSocketAddr;
use std::{io, net::SocketAddr};
use socket2::SockAddr;
use std::io;

/// Open a file
pub(crate) struct Connect {
fd: SharedFd,
os_socket_addr: OsSocketAddr,
socket_addr: SockAddr,
}

impl Op<Connect> {
/// Submit a request to connect.
pub(crate) fn connect(fd: &SharedFd, socket_addr: SocketAddr) -> io::Result<Op<Connect>> {
pub(crate) fn connect(fd: &SharedFd, socket_addr: SockAddr) -> io::Result<Op<Connect>> {
use io_uring::{opcode, types};

let os_socket_addr = OsSocketAddr::from(socket_addr);

Op::submit_with(
Connect {
fd: fd.clone(),
os_socket_addr,
socket_addr,
},
|connect| {
opcode::Connect::new(
types::Fd(connect.fd.raw_fd()),
connect.os_socket_addr.as_ptr(),
connect.os_socket_addr.len(),
connect.socket_addr.as_ptr(),
connect.socket_addr.len(),
)
.build()
},
Expand Down
2 changes: 1 addition & 1 deletion src/driver/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ where
Poll::Ready(Completion {
data: me.data.take().expect("unexpected operation state"),
result,
flags: flags,
flags,
})
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/driver/recv_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
driver::{Op, SharedFd},
BufResult,
};
use os_socketaddr::OsSocketAddr;
use socket2::SockAddr;
use std::{
io::IoSliceMut,
task::{Context, Poll},
Expand All @@ -15,7 +15,7 @@ pub(crate) struct RecvFrom<T> {
fd: SharedFd,
pub(crate) buf: T,
io_slices: Vec<IoSliceMut<'static>>,
pub(crate) os_socket_addr: Box<OsSocketAddr>,
pub(crate) socket_addr: Box<SockAddr>,
pub(crate) msghdr: Box<libc::msghdr>,
}

Expand All @@ -27,20 +27,20 @@ impl<T: IoBufMut> Op<RecvFrom<T>> {
std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total())
})];

let mut os_socket_addr = Box::new(OsSocketAddr::new());
let socket_addr = Box::new(unsafe { SockAddr::init(|_, _| Ok(()))?.1 });

let mut msghdr: Box<libc::msghdr> = Box::new(unsafe { std::mem::zeroed() });
msghdr.msg_iov = io_slices.as_mut_ptr().cast();
msghdr.msg_iovlen = io_slices.len() as _;
msghdr.msg_name = os_socket_addr.as_mut_ptr() as *mut libc::c_void;
msghdr.msg_namelen = os_socket_addr.capacity();
msghdr.msg_name = socket_addr.as_ptr() as *mut libc::c_void;
msghdr.msg_namelen = socket_addr.len();

Op::submit_with(
RecvFrom {
fd: fd.clone(),
buf,
io_slices,
os_socket_addr,
socket_addr,
msghdr,
},
|recv_from| {
Expand Down Expand Up @@ -74,7 +74,7 @@ impl<T: IoBufMut> Op<RecvFrom<T>> {
let result = match complete.result {
Ok(v) => {
let v = v as usize;
let socket_addr: Option<SocketAddr> = (*complete.data.os_socket_addr).into();
let socket_addr: Option<SocketAddr> = (*complete.data.socket_addr).as_socket();
// If the operation was successful, advance the initialized cursor.
// Safety: the kernel wrote `v` bytes to the buffer.
unsafe {
Expand Down
12 changes: 6 additions & 6 deletions src/driver/send_to.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::buf::IoBuf;
use crate::driver::{Op, SharedFd};
use crate::BufResult;
use os_socketaddr::OsSocketAddr;
use socket2::SockAddr;
use std::io::IoSlice;
use std::task::{Context, Poll};
use std::{boxed::Box, io, net::SocketAddr};
Expand All @@ -13,7 +13,7 @@ pub(crate) struct SendTo<T> {
#[allow(dead_code)]
io_slices: Vec<IoSlice<'static>>,
#[allow(dead_code)]
os_socket_addr: Box<OsSocketAddr>,
socket_addr: Box<SockAddr>,
pub(crate) msghdr: Box<libc::msghdr>,
}

Expand All @@ -29,20 +29,20 @@ impl<T: IoBuf> Op<SendTo<T>> {
std::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init())
})];

let mut os_socket_addr = Box::new(OsSocketAddr::from(socket_addr));
let socket_addr = Box::new(SockAddr::from(socket_addr));

let mut msghdr: Box<libc::msghdr> = Box::new(unsafe { std::mem::zeroed() });
msghdr.msg_iov = io_slices.as_ptr() as *mut _;
msghdr.msg_iovlen = io_slices.len() as _;
msghdr.msg_name = os_socket_addr.as_mut_ptr() as *mut libc::c_void;
msghdr.msg_namelen = os_socket_addr.len();
msghdr.msg_name = socket_addr.as_ptr() as *mut libc::c_void;
msghdr.msg_namelen = socket_addr.len();

Op::submit_with(
SendTo {
fd: fd.clone(),
buf,
io_slices,
os_socket_addr,
socket_addr,
msghdr,
},
|send_to| {
Expand Down
76 changes: 56 additions & 20 deletions src/driver/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use crate::{
buf::{IoBuf, IoBufMut},
driver::{Op, SharedFd},
};
use os_socketaddr::OsSocketAddr;
use std::{
io,
net::SocketAddr,
os::unix::io::{AsRawFd, RawFd},
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
path::Path,
};

#[derive(Clone)]
Expand All @@ -26,7 +26,15 @@ impl Socket {
pub(crate) fn new(socket_addr: SocketAddr, socket_type: libc::c_int) -> io::Result<Socket> {
let socket_type = socket_type | libc::SOCK_CLOEXEC;
let domain = get_domain(socket_addr);
let fd = syscall!(socket(domain, socket_type, 0))?;
let fd = socket2::Socket::new(domain.into(), socket_type.into(), None)?.into_raw_fd();
let fd = SharedFd::new(fd);
Ok(Socket { fd })
}

pub(crate) fn new_unix(socket_type: libc::c_int) -> io::Result<Socket> {
let socket_type = socket_type | libc::SOCK_CLOEXEC;
let domain = libc::AF_UNIX;
let fd = socket2::Socket::new(domain.into(), socket_type.into(), None)?.into_raw_fd();
let fd = SharedFd::new(fd);
Ok(Socket { fd })
}
Expand Down Expand Up @@ -58,38 +66,66 @@ impl Socket {
op.recv().await
}

pub(crate) async fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
pub(crate) async fn accept(&self) -> io::Result<(Socket, Option<SocketAddr>)> {
let op = Op::accept(&self.fd)?;
let completion = op.await;
let fd = completion.result?;
let fd = SharedFd::new(fd as i32);
let data = completion.data;
let socket = Socket { fd };
let os_socket_addr = unsafe {
OsSocketAddr::from_raw_parts(
&completion.data.socketaddr.0 as *const _ as _,
completion.data.socketaddr.1 as usize,
)
let (_, addr) = unsafe {
socket2::SockAddr::init(move |addr_storage, len| {
*addr_storage = data.socketaddr.0.to_owned();
*len = data.socketaddr.1;
Ok(())
})?
};
let socket_addr = os_socket_addr.into_addr().unwrap();
Ok((socket, socket_addr))
Ok((socket, addr.as_socket()))
}

pub(crate) async fn connect(&self, socket_addr: SocketAddr) -> io::Result<()> {
pub(crate) async fn connect(&self, socket_addr: socket2::SockAddr) -> io::Result<()> {
let op = Op::connect(&self.fd, socket_addr)?;
let completion = op.await;
completion.result?;
Ok(())
}

pub(crate) fn bind(socket_addr: SocketAddr, socket_type: libc::c_int) -> io::Result<Socket> {
let socket = Socket::new(socket_addr, socket_type)?;
let os_socket_addr = OsSocketAddr::from(socket_addr);
syscall!(bind(
socket.as_raw_fd(),
os_socket_addr.as_ptr(),
os_socket_addr.len()
))?;
Ok(socket)
Self::bind_internal(
socket_addr.into(),
get_domain(socket_addr).into(),
socket_type.into(),
)
}

pub(crate) fn bind_unix<P: AsRef<Path>>(
path: P,
socket_type: libc::c_int,
) -> io::Result<Socket> {
let addr = socket2::SockAddr::unix(path.as_ref())?;
Self::bind_internal(addr, libc::AF_UNIX.into(), socket_type.into())
}

fn bind_internal(
socket_addr: socket2::SockAddr,
domain: socket2::Domain,
socket_type: socket2::Type,
) -> io::Result<Socket> {
let sys_listener = socket2::Socket::new(domain, socket_type, None)?;
let addr = socket2::SockAddr::from(socket_addr);

sys_listener.set_reuse_port(true)?;
sys_listener.set_reuse_address(true)?;

// TODO: config for buffer sizes
// sys_listener.set_send_buffer_size(send_buf_size)?;
// sys_listener.set_recv_buffer_size(recv_buf_size)?;

sys_listener.bind(&addr)?;

let fd = SharedFd::new(sys_listener.into_raw_fd());

Ok(Self { fd })
}

pub(crate) fn listen(&self, backlog: libc::c_int) -> io::Result<()> {
Expand Down
2 changes: 2 additions & 0 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
mod tcp;
mod udp;
mod unix;

pub use tcp::{TcpListener, TcpStream};
pub use udp::UdpSocket;
pub use unix::{UnixListener, UnixStream};
13 changes: 9 additions & 4 deletions src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ impl TcpListener {
/// The returned listener is ready for accepting connections.
///
/// Binding with a port number of 0 will request that the OS assigns a port
/// to this listener. The port allocated can be queried via the `local_addr`
/// to this listener.
///
/// In the future, the port allocated can be queried via a (blocking) `local_addr`
/// method.
pub fn bind(socket_addr: SocketAddr) -> io::Result<TcpListener> {
let socket = Socket::bind(socket_addr, libc::SOCK_STREAM)?;
pub fn bind(addr: SocketAddr) -> io::Result<Self> {
let socket = Socket::bind(addr, libc::SOCK_STREAM)?;
socket.listen(1024)?;
Ok(TcpListener { inner: socket })
return Ok(TcpListener { inner: socket });
}

/// Accepts a new incoming connection from this listener.
Expand All @@ -59,6 +61,9 @@ impl TcpListener {
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
let (socket, socket_addr) = self.inner.accept().await?;
let stream = TcpStream { inner: socket };
let socket_addr = socket_addr.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Could not get socket IP address")
})?;
Ok((stream, socket_addr))
}
}
Loading

0 comments on commit 3176149

Please sign in to comment.