Skip to content

Commit

Permalink
draft changes to make add safe
Browse files Browse the repository at this point in the history
  • Loading branch information
keepsimple1 committed Feb 15, 2024
1 parent 77b4ed1 commit ad676e3
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 71 deletions.
4 changes: 1 addition & 3 deletions examples/tcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use socket2::Type;
fn main() -> io::Result<()> {
let socket = socket2::Socket::new(socket2::Domain::IPV4, Type::STREAM, None)?;
let poller = polling::Poller::new()?;
unsafe {
poller.add(&socket, Event::new(0, true, true))?;
}
poller.add(&socket, Event::new(0, true, true))?;
let addr = net::SocketAddr::new(net::Ipv4Addr::LOCALHOST.into(), 8080);
socket.set_nonblocking(true)?;
let _ = socket.connect(&addr.into());
Expand Down
6 changes: 2 additions & 4 deletions examples/two-listeners.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ fn main() -> io::Result<()> {
l2.set_nonblocking(true)?;

let poller = Poller::new()?;
unsafe {
poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;
}
poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;

println!("You can connect to the server using `nc`:");
println!(" $ nc 127.0.0.1 8001");
Expand Down
12 changes: 9 additions & 3 deletions src/kqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ impl Poller {
/// # Safety
///
/// The file descriptor must be valid and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
self.add_source(SourceId::Fd(fd))?;
pub fn add(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let rawfd = fd.as_raw_fd();

// SAFETY: `rawfd` is valid as it is from `BorrowedFd`. And
// this block never closes / deletes `rawfd`.
unsafe {
self.add_source(SourceId::Fd(rawfd))?;
}

// File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
self.modify(fd, ev, mode)
}

/// Modifies an existing file descriptor.
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ impl Poller {
/// poller.delete(&source)?;
/// # std::io::Result::Ok(())
/// ```
pub unsafe fn add(&self, source: impl AsRawSource, interest: Event) -> io::Result<()> {
pub fn add(&self, source: impl AsSource, interest: Event) -> io::Result<()> {
self.add_with_mode(source, interest, PollMode::Oneshot)
}

Expand All @@ -526,9 +526,9 @@ impl Poller {
///
/// If the operating system does not support the specified mode, this function
/// will return an error.
pub unsafe fn add_with_mode(
pub fn add_with_mode(
&self,
source: impl AsRawSource,
source: impl AsSource,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
Expand All @@ -538,7 +538,7 @@ impl Poller {
"the key is not allowed to be `usize::MAX`",
));
}
self.poller.add(source.raw(), interest, mode)
self.poller.add(source.as_fd(), interest, mode)
}

/// Modifies the interest in a file descriptor or socket.
Expand Down
12 changes: 3 additions & 9 deletions tests/concurrent_modification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ fn concurrent_add() -> io::Result<()> {
})
.add(|| {
thread::sleep(Duration::from_millis(100));
unsafe {
poller.add(&reader, Event::readable(0))?;
}
poller.add(&reader, Event::readable(0))?;
writer.write_all(&[1])?;
Ok(())
})
Expand All @@ -46,9 +44,7 @@ fn concurrent_add() -> io::Result<()> {
fn concurrent_modify() -> io::Result<()> {
let (reader, mut writer) = tcp_pair()?;
let poller = Poller::new()?;
unsafe {
poller.add(&reader, Event::none(0))?;
}
poller.add(&reader, Event::none(0))?;

let mut events = Events::new();

Expand Down Expand Up @@ -84,9 +80,7 @@ fn concurrent_interruption() -> io::Result<()> {

let (reader, _writer) = tcp_pair()?;
let poller = Poller::new()?;
unsafe {
poller.add(&reader, Event::none(0))?;
}
poller.add(&reader, Event::none(0))?;

let mut events = Events::new();
let events_borrow = &mut events;
Expand Down
26 changes: 11 additions & 15 deletions tests/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ use std::time::Duration;
fn basic_io() {
let poller = Poller::new().unwrap();
let (read, mut write) = tcp_pair().unwrap();
unsafe {
poller.add(&read, Event::readable(1)).unwrap();
}
poller.add(&read, Event::readable(1)).unwrap();

// Nothing should be available at first.
let mut events = Events::new();
Expand Down Expand Up @@ -42,26 +40,24 @@ fn basic_io() {
#[test]
fn insert_twice() {
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
use std::os::unix::io::AsFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;

let (read, mut write) = tcp_pair().unwrap();
let read = Arc::new(read);

let poller = Poller::new().unwrap();
unsafe {
#[cfg(unix)]
let read = read.as_raw_fd();
#[cfg(windows)]
let read = read.as_raw_socket();
#[cfg(unix)]
let read = read.as_fd();
#[cfg(windows)]
let read = read.as_raw_socket();

poller.add(read, Event::readable(1)).unwrap();
assert_eq!(
poller.add(read, Event::readable(1)).unwrap_err().kind(),
io::ErrorKind::AlreadyExists
);
}
poller.add(read, Event::readable(1)).unwrap();
assert_eq!(
poller.add(read, Event::readable(1)).unwrap_err().kind(),
io::ErrorKind::AlreadyExists
);

write.write_all(&[1]).unwrap();
let mut events = Events::new();
Expand Down
4 changes: 1 addition & 3 deletions tests/many_connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ fn many_connections() {
let poller = polling::Poller::new().unwrap();

for (i, reader, _) in connections.iter() {
unsafe {
poller.add(reader, polling::Event::readable(*i)).unwrap();
}
poller.add(reader, polling::Event::readable(*i)).unwrap();
}

let mut events = Events::new();
Expand Down
42 changes: 18 additions & 24 deletions tests/multiple_pollers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ fn level_triggered() {

// Register the source into both pollers.
let (mut reader, mut writer) = tcp_pair().unwrap();
unsafe {
poller1
.add_with_mode(&reader, Event::readable(1), PollMode::Level)
.unwrap();
poller2
.add_with_mode(&reader, Event::readable(2), PollMode::Level)
.unwrap();
}
poller1
.add_with_mode(&reader, Event::readable(1), PollMode::Level)
.unwrap();
poller2
.add_with_mode(&reader, Event::readable(2), PollMode::Level)
.unwrap();

// Neither poller should have any events.
assert_eq!(
Expand Down Expand Up @@ -139,14 +137,12 @@ fn edge_triggered() {

// Register the source into both pollers.
let (mut reader, mut writer) = tcp_pair().unwrap();
unsafe {
poller1
.add_with_mode(&reader, Event::readable(1), PollMode::Edge)
.unwrap();
poller2
.add_with_mode(&reader, Event::readable(2), PollMode::Edge)
.unwrap();
}
poller1
.add_with_mode(&reader, Event::readable(1), PollMode::Edge)
.unwrap();
poller2
.add_with_mode(&reader, Event::readable(2), PollMode::Edge)
.unwrap();

// Neither poller should have any events.
assert_eq!(
Expand Down Expand Up @@ -256,14 +252,12 @@ fn oneshot_triggered() {

// Register the source into both pollers.
let (mut reader, mut writer) = tcp_pair().unwrap();
unsafe {
poller1
.add_with_mode(&reader, Event::readable(1), PollMode::Oneshot)
.unwrap();
poller2
.add_with_mode(&reader, Event::readable(2), PollMode::Oneshot)
.unwrap();
}
poller1
.add_with_mode(&reader, Event::readable(1), PollMode::Oneshot)
.unwrap();
poller2
.add_with_mode(&reader, Event::readable(2), PollMode::Oneshot)
.unwrap();

// Neither poller should have any events.
assert_eq!(
Expand Down
13 changes: 7 additions & 6 deletions tests/other_modes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ fn level_triggered() {

// Create our poller and register our streams.
let poller = Poller::new().unwrap();
if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) }
if poller
.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level)
.is_err()
{
// Only panic if we're on a platform that should support level mode.
Expand Down Expand Up @@ -104,7 +105,8 @@ fn edge_triggered() {

// Create our poller and register our streams.
let poller = Poller::new().unwrap();
if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) }
if poller
.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge)
.is_err()
{
// Only panic if we're on a platform that should support level mode.
Expand Down Expand Up @@ -194,14 +196,13 @@ fn edge_oneshot_triggered() {

// Create our poller and register our streams.
let poller = Poller::new().unwrap();
if unsafe {
poller.add_with_mode(
if poller
.add_with_mode(
&reader,
Event::readable(reader_token),
PollMode::EdgeOneshot,
)
}
.is_err()
.is_err()
{
// Only panic if we're on a platform that should support level mode.
cfg_if::cfg_if! {
Expand Down

0 comments on commit ad676e3

Please sign in to comment.