Skip to content

Commit

Permalink
In ConstructCoroutineInClosureShim, pass receiver by ref, not pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Mar 25, 2024
1 parent 60b5ca6 commit dce205b
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 20 deletions.
6 changes: 3 additions & 3 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,8 +1015,8 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
bug!();
};

// We use `*mut Self` here because we only need to emit an ABI-compatible shim body,
// rather than match the signature exactly.
// We use `&mut Self` here because we only need to emit an ABI-compatible shim body,
// rather than match the signature exactly (which might take `&self` instead).
//
// The self type here is a coroutine-closure, not a coroutine, and we never read from
// it because it never has any captures, because this is only true in the Fn/FnMut
Expand All @@ -1025,7 +1025,7 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
if receiver_by_ref {
// Triple-check that there's no captures here.
assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit);
self_ty = Ty::new_mut_ptr(tcx, self_ty);
self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty);
}

let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
Expand Down
8 changes: 6 additions & 2 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ fn fn_sig_for_fn_abi<'tcx>(
coroutine_kind = ty::ClosureKind::FnOnce;

// Implementations of `FnMut` and `Fn` for coroutine-closures
// still take their receiver by ref.
if receiver_by_ref { Ty::new_mut_ptr(tcx, coroutine_ty) } else { coroutine_ty }
// still take their receiver by (mut) ref.
if receiver_by_ref {
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty)
} else {
coroutine_ty
}
} else {
tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region)
};
Expand Down
40 changes: 40 additions & 0 deletions src/tools/miri/tests/pass/async-closure-drop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#![feature(async_closure, noop_waker, async_fn_traits)]

use std::future::Future;
use std::pin::pin;
use std::task::*;

pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
let mut fut = pin!(fut);
let ctx = &mut Context::from_waker(Waker::noop());

loop {
match fut.as_mut().poll(ctx) {
Poll::Pending => {}
Poll::Ready(t) => break t,
}
}
}

async fn call_once(f: impl async FnOnce(DropMe)) {
f(DropMe("world")).await;
}

#[derive(Debug)]
struct DropMe(&'static str);

impl Drop for DropMe {
fn drop(&mut self) {
println!("{}", self.0);
}
}

pub fn main() {
block_on(async {
let b = DropMe("hello");
let async_closure = async move |a: DropMe| {
println!("{a:?} {b:?}");
};
call_once(async_closure).await;
});
}
3 changes: 3 additions & 0 deletions src/tools/miri/tests/pass/async-closure-drop.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
DropMe("world") DropMe("hello")
world
hello
36 changes: 24 additions & 12 deletions src/tools/miri/tests/pass/async-closure.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(async_closure, noop_waker, async_fn_traits)]

use std::future::Future;
use std::ops::{AsyncFnMut, AsyncFnOnce};
use std::pin::pin;
use std::task::*;

Expand All @@ -16,25 +17,36 @@ pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
}
}

async fn call_once(f: impl async FnOnce(DropMe)) {
f(DropMe("world")).await;
async fn call_mut(f: &mut impl AsyncFnMut(i32)) {
f(0).await;
}

#[derive(Debug)]
struct DropMe(&'static str);
async fn call_once(f: impl AsyncFnOnce(i32)) {
f(1).await;
}

impl Drop for DropMe {
fn drop(&mut self) {
println!("{}", self.0);
}
async fn call_normal<F: Future<Output = ()>>(f: &impl Fn(i32) -> F) {
f(0).await;
}

async fn call_normal_once<F: Future<Output = ()>>(f: impl FnOnce(i32) -> F) {
f(1).await;
}

pub fn main() {
block_on(async {
let b = DropMe("hello");
let async_closure = async move |a: DropMe| {
println!("{a:?} {b:?}");
let b = 2i32;
let mut async_closure = async move |a: i32| {
println!("{a} {b}");
};
call_mut(&mut async_closure).await;
call_once(async_closure).await;

// No-capture closures implement `Fn`.
let async_closure = async move |a: i32| {
println!("{a}");
};
call_normal(&async_closure).await;
call_normal_once(async_closure).await;
});
}
}
7 changes: 4 additions & 3 deletions src/tools/miri/tests/pass/async-closure.stdout
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DropMe("world") DropMe("hello")
world
hello
0 2
1 2
0
1

0 comments on commit dce205b

Please sign in to comment.