Skip to content

Commit

Permalink
add a macro to declare thread unblock callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed May 26, 2024
1 parent e6bb468 commit 2e89443
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 249 deletions.
2 changes: 1 addition & 1 deletion src/tools/miri/src/concurrency/init_once.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn init_once_enqueue_and_block(
&mut self,
id: InitOnceId,
callback: impl UnblockCallback<'mir, 'tcx> + 'tcx,
callback: impl UnblockCallback<'tcx> + 'tcx,
) {
let this = self.eval_context_mut();
let thread = this.active_thread();
Expand Down
301 changes: 130 additions & 171 deletions src/tools/miri/src/concurrency/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ macro_rules! declare_id {
}
}

impl $crate::VisitProvenance for $name {
fn visit_provenance(&self, _visit: &mut VisitWith<'_>) {}
}

impl Idx for $name {
fn new(idx: usize) -> Self {
// We use 0 as a sentinel value (see the comment above) and,
Expand Down Expand Up @@ -258,6 +262,25 @@ pub(super) trait EvalContextExtPriv<'mir, 'tcx: 'mir>:
Ok(new_index)
}
}

fn condvar_reacquire_mutex(
&mut self,
mutex: MutexId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
if this.mutex_is_locked(mutex) {
assert_ne!(this.mutex_get_owner(mutex), this.active_thread());
this.mutex_enqueue_and_block(mutex, retval, dest);
} else {
// We can have it right now!
this.mutex_lock(mutex);
// Don't forget to write the return value.
this.write_scalar(retval, &dest)?;
}
Ok(())
}
}

// Public interface to synchronization primitives. Please note that in most
Expand Down Expand Up @@ -384,29 +407,23 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert!(this.mutex_is_locked(id), "queing on unlocked mutex");
let thread = this.active_thread();
this.machine.sync.mutexes[id].queue.push_back(thread);
this.block_thread(BlockReason::Mutex(id), None, Callback { id, retval, dest });

struct Callback<'tcx> {
id: MutexId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
impl<'tcx> VisitProvenance for Callback<'tcx> {
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
let Callback { id: _, retval, dest } = self;
retval.visit_provenance(visit);
dest.visit_provenance(visit);
}
}
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
assert!(!this.mutex_is_locked(self.id));
this.mutex_lock(self.id);

this.write_scalar(self.retval, &self.dest)?;
Ok(())
}
}
this.block_thread(
BlockReason::Mutex(id),
None,
callback!(
@capture<'tcx> {
id: MutexId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
@unblock = |this| {
assert!(!this.mutex_is_locked(id));
this.mutex_lock(id);
this.write_scalar(retval, &dest)?;
Ok(())
}
),
);
}

#[inline]
Expand Down Expand Up @@ -500,27 +517,22 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let thread = this.active_thread();
assert!(this.rwlock_is_write_locked(id), "read-queueing on not write locked rwlock");
this.machine.sync.rwlocks[id].reader_queue.push_back(thread);
this.block_thread(BlockReason::RwLock(id), None, Callback { id, retval, dest });

struct Callback<'tcx> {
id: RwLockId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
impl<'tcx> VisitProvenance for Callback<'tcx> {
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
let Callback { id: _, retval, dest } = self;
retval.visit_provenance(visit);
dest.visit_provenance(visit);
}
}
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
this.rwlock_reader_lock(self.id);
this.write_scalar(self.retval, &self.dest)?;
Ok(())
}
}
this.block_thread(
BlockReason::RwLock(id),
None,
callback!(
@capture<'tcx> {
id: RwLockId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
@unblock = |this| {
this.rwlock_reader_lock(id);
this.write_scalar(retval, &dest)?;
Ok(())
}
),
);
}

/// Lock by setting the writer that owns the lock.
Expand Down Expand Up @@ -588,27 +600,22 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert!(this.rwlock_is_locked(id), "write-queueing on unlocked rwlock");
let thread = this.active_thread();
this.machine.sync.rwlocks[id].writer_queue.push_back(thread);
this.block_thread(BlockReason::RwLock(id), None, Callback { id, retval, dest });

struct Callback<'tcx> {
id: RwLockId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
impl<'tcx> VisitProvenance for Callback<'tcx> {
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
let Callback { id: _, retval, dest } = self;
retval.visit_provenance(visit);
dest.visit_provenance(visit);
}
}
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
this.rwlock_writer_lock(self.id);
this.write_scalar(self.retval, &self.dest)?;
Ok(())
}
}
this.block_thread(
BlockReason::RwLock(id),
None,
callback!(
@capture<'tcx> {
id: RwLockId,
retval: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
@unblock = |this| {
this.rwlock_writer_lock(id);
this.write_scalar(retval, &dest)?;
Ok(())
}
),
);
}

/// Is the conditional variable awaited?
Expand Down Expand Up @@ -648,71 +655,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.block_thread(
BlockReason::Condvar(condvar),
timeout,
Callback { condvar, mutex, retval_succ, retval_timeout, dest },
);
return Ok(());

struct Callback<'tcx> {
condvar: CondvarId,
mutex: MutexId,
retval_succ: Scalar<Provenance>,
retval_timeout: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
impl<'tcx> VisitProvenance for Callback<'tcx> {
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
let Callback { condvar: _, mutex: _, retval_succ, retval_timeout, dest } = self;
retval_succ.visit_provenance(visit);
retval_timeout.visit_provenance(visit);
dest.visit_provenance(visit);
}
}
impl<'tcx, 'mir> Callback<'tcx> {
#[allow(clippy::boxed_local)]
fn reacquire_mutex(
self: Box<Self>,
this: &mut MiriInterpCx<'mir, 'tcx>,
retval: Scalar<Provenance>,
) -> InterpResult<'tcx> {
if this.mutex_is_locked(self.mutex) {
assert_ne!(this.mutex_get_owner(self.mutex), this.active_thread());
this.mutex_enqueue_and_block(self.mutex, retval, self.dest);
} else {
// We can have it right now!
this.mutex_lock(self.mutex);
// Don't forget to write the return value.
this.write_scalar(retval, &self.dest)?;
callback!(
@capture<'tcx> {
condvar: CondvarId,
mutex: MutexId,
retval_succ: Scalar<Provenance>,
retval_timeout: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
}
Ok(())
}
}
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
// The condvar was signaled. Make sure we get the clock for that.
if let Some(data_race) = &this.machine.data_race {
data_race.acquire_clock(
&this.machine.sync.condvars[self.condvar].clock,
&this.machine.threads,
);
@unblock = |this| {
// The condvar was signaled. Make sure we get the clock for that.
if let Some(data_race) = &this.machine.data_race {
data_race.acquire_clock(
&this.machine.sync.condvars[condvar].clock,
&this.machine.threads,
);
}
// Try to acquire the mutex.
// The timeout only applies to the first wait (until the signal), not for mutex acquisition.
this.condvar_reacquire_mutex(mutex, retval_succ, dest)
}
// Try to acquire the mutex.
// The timeout only applies to the first wait (until the signal), not for mutex acquisition.
let retval = self.retval_succ;
self.reacquire_mutex(this, retval)
}
fn timeout(
self: Box<Self>,
this: &mut InterpCx<'mir, 'tcx, MiriMachine<'mir, 'tcx>>,
) -> InterpResult<'tcx> {
// We have to remove the waiter from the queue again.
let thread = this.active_thread();
let waiters = &mut this.machine.sync.condvars[self.condvar].waiters;
waiters.retain(|waiter| *waiter != thread);
// Now get back the lock.
let retval = self.retval_timeout;
self.reacquire_mutex(this, retval)
}
}
@timeout = |this| {
// We have to remove the waiter from the queue again.
let thread = this.active_thread();
let waiters = &mut this.machine.sync.condvars[condvar].waiters;
waiters.retain(|waiter| *waiter != thread);
// Now get back the lock.
this.condvar_reacquire_mutex(mutex, retval_timeout, dest)
}
),
);
return Ok(());
}

/// Wake up some thread (if there is any) sleeping on the conditional
Expand Down Expand Up @@ -755,50 +728,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.block_thread(
BlockReason::Futex { addr },
timeout,
Callback { addr, retval_succ, retval_timeout, dest, errno_timeout },
);

struct Callback<'tcx> {
addr: u64,
retval_succ: Scalar<Provenance>,
retval_timeout: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
errno_timeout: Scalar<Provenance>,
}
impl<'tcx> VisitProvenance for Callback<'tcx> {
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
let Callback { addr: _, retval_succ, retval_timeout, dest, errno_timeout } = self;
retval_succ.visit_provenance(visit);
retval_timeout.visit_provenance(visit);
dest.visit_provenance(visit);
errno_timeout.visit_provenance(visit);
}
}
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
let futex = this.machine.sync.futexes.get(&self.addr).unwrap();
// Acquire the clock of the futex.
if let Some(data_race) = &this.machine.data_race {
data_race.acquire_clock(&futex.clock, &this.machine.threads);
callback!(
@capture<'tcx> {
addr: u64,
retval_succ: Scalar<Provenance>,
retval_timeout: Scalar<Provenance>,
dest: MPlaceTy<'tcx, Provenance>,
errno_timeout: Scalar<Provenance>,
}
// Write the return value.
this.write_scalar(self.retval_succ, &self.dest)?;
Ok(())
}
fn timeout(
self: Box<Self>,
this: &mut InterpCx<'mir, 'tcx, MiriMachine<'mir, 'tcx>>,
) -> InterpResult<'tcx> {
// Remove the waiter from the futex.
let thread = this.active_thread();
let futex = this.machine.sync.futexes.get_mut(&self.addr).unwrap();
futex.waiters.retain(|waiter| waiter.thread != thread);
// Set errno and write return value.
this.set_last_error(self.errno_timeout)?;
this.write_scalar(self.retval_timeout, &self.dest)?;
Ok(())
}
}
@unblock = |this| {
let futex = this.machine.sync.futexes.get(&addr).unwrap();
// Acquire the clock of the futex.
if let Some(data_race) = &this.machine.data_race {
data_race.acquire_clock(&futex.clock, &this.machine.threads);
}
// Write the return value.
this.write_scalar(retval_succ, &dest)?;
Ok(())
}
@timeout = |this| {
// Remove the waiter from the futex.
let thread = this.active_thread();
let futex = this.machine.sync.futexes.get_mut(&addr).unwrap();
futex.waiters.retain(|waiter| waiter.thread != thread);
// Set errno and write return value.
this.set_last_error(errno_timeout)?;
this.write_scalar(retval_timeout, &dest)?;
Ok(())
}
),
);
}

/// Returns whether anything was woken.
Expand Down
Loading

0 comments on commit 2e89443

Please sign in to comment.