From 5df00fc14fa594603a827ed64113ad933317d73a Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 28 Feb 2024 20:25:25 +0000 Subject: [PATCH] Fix ABI for FnMut/Fn impls for async closures --- compiler/rustc_middle/src/mir/visit.rs | 1 + compiler/rustc_middle/src/ty/instance.rs | 11 ++++++++- compiler/rustc_mir_transform/src/shim.rs | 24 +++++++++++++++---- compiler/rustc_ty_utils/src/abi.rs | 15 ++++++++---- compiler/rustc_ty_utils/src/instance.rs | 2 ++ ...ure#0}.coroutine_by_move.0.panic-abort.mir | 2 +- ...re#0}.coroutine_by_move.0.panic-unwind.mir | 2 +- ...oroutine_closure_by_move.0.panic-abort.mir | 6 ++--- ...routine_closure_by_move.0.panic-unwind.mir | 6 ++--- ...coroutine_closure_by_ref.0.panic-abort.mir | 10 ++++++++ ...oroutine_closure_by_ref.0.panic-unwind.mir | 10 ++++++++ tests/mir-opt/async_closure_shims.rs | 10 ++++++++ 12 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 14566f8c4f859..625aaf6f215f6 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -347,6 +347,7 @@ macro_rules! make_mir_visitor { ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } | ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, + receiver_by_ref: _, } | ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id } | ty::InstanceDef::DropGlue(_def_id, None) => {} diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index e097f830d85b2..2ccd83dcbbf4b 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -95,7 +95,15 @@ pub enum InstanceDef<'tcx> { /// The body generated here differs significantly from the `ClosureOnceShim`, /// since we need to generate a distinct coroutine type that will move the /// closure's upvars *out* of the closure. - ConstructCoroutineInClosureShim { coroutine_closure_def_id: DefId }, + ConstructCoroutineInClosureShim { + coroutine_closure_def_id: DefId, + // Whether the generated MIR body takes the coroutine by-ref. This is + // because the signature of `<{async fn} as FnMut>::call_mut` is: + // `fn(&mut self, args: A) -> ::Output`, that is to say + // that it returns the `FnOnce`-flavored coroutine but takes the closure + // by ref (and similarly for `Fn::call`). + receiver_by_ref: bool, + }, /// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce` /// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or @@ -188,6 +196,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ } | ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: def_id, + receiver_by_ref: _, } | ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id } | InstanceDef::DropGlue(def_id, _) diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 3efaa69a7e780..4b2243598dc1d 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -70,9 +70,10 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut)) } - ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => { - build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id) - } + ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + receiver_by_ref, + } => build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id, receiver_by_ref), ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => { return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone(); @@ -1015,12 +1016,17 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t fn build_construct_coroutine_by_move_shim<'tcx>( tcx: TyCtxt<'tcx>, coroutine_closure_def_id: DefId, + receiver_by_ref: bool, ) -> Body<'tcx> { - let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); let ty::CoroutineClosure(_, args) = *self_ty.kind() else { bug!(); }; + if receiver_by_ref { + self_ty = Ty::new_mut_ptr(tcx, self_ty); + } + let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { tcx.mk_fn_sig( [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()), @@ -1076,11 +1082,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>( let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id, + receiver_by_ref, }); let body = new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span); - dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(())); + dump_mir( + tcx, + false, + if receiver_by_ref { "coroutine_closure_by_ref" } else { "coroutine_closure_by_move" }, + &0, + &body, + |_, _| Ok(()), + ); body } diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index 3c439d95efffc..9b3b614e2bb35 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -118,11 +118,18 @@ fn fn_sig_for_fn_abi<'tcx>( // a separate def-id for these bodies. let mut coroutine_kind = args.as_coroutine_closure().kind(); - if let InstanceDef::ConstructCoroutineInClosureShim { .. } = instance.def { - coroutine_kind = ty::ClosureKind::FnOnce; - } + let env_ty = + if let InstanceDef::ConstructCoroutineInClosureShim { receiver_by_ref, .. } = + instance.def + { + coroutine_kind = ty::ClosureKind::FnOnce; - let env_ty = tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region); + // Implementations of `FnMut` and `Fn` for coroutine-closures + // still take their receiver by ref. + if receiver_by_ref { Ty::new_mut_ptr(tcx, coroutine_ty) } else { coroutine_ty } + } else { + tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region) + }; let sig = sig.skip_binder(); ty::Binder::bind_with_vars( diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index 05b75bfd0d19c..edcdc8c6220ec 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -283,6 +283,7 @@ fn resolve_associated_item<'tcx>( Some(Instance { def: ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id, + receiver_by_ref: target_kind != ty::ClosureKind::FnOnce, }, args, }) @@ -310,6 +311,7 @@ fn resolve_associated_item<'tcx>( Some(Instance { def: ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id, + receiver_by_ref: false, }, args, }) diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir index 6ca3dd6100572..06028487d0178 100644 --- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir @@ -1,6 +1,6 @@ // MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move -fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> () +fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> () yields () { debug _task_context => _2; diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir index 6ca3dd6100572..06028487d0178 100644 --- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir @@ -1,6 +1,6 @@ // MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move -fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> () +fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> () yields () { debug _task_context => _2; diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir index b5768e14452cd..93447b1388dea 100644 --- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir @@ -1,10 +1,10 @@ // MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move -fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} { - let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}; +fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} { + let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}; bb0: { - _0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) }; + _0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) }; return; } } diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir index b5768e14452cd..93447b1388dea 100644 --- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir @@ -1,10 +1,10 @@ // MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move -fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} { - let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}; +fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} { + let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}; bb0: { - _0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) }; + _0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) }; return; } } diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir new file mode 100644 index 0000000000000..f51540bcfff75 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir @@ -0,0 +1,10 @@ +// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref + +fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} { + let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10}; + + bb0: { + _0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 }; + return; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir new file mode 100644 index 0000000000000..f51540bcfff75 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir @@ -0,0 +1,10 @@ +// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref + +fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} { + let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10}; + + bb0: { + _0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 }; + return; + } +} diff --git a/tests/mir-opt/async_closure_shims.rs b/tests/mir-opt/async_closure_shims.rs index 47c41ed0500bd..7d226df686654 100644 --- a/tests/mir-opt/async_closure_shims.rs +++ b/tests/mir-opt/async_closure_shims.rs @@ -29,8 +29,13 @@ async fn call_once(f: impl AsyncFnOnce(i32)) { f(1).await; } +async fn call_normal>(f: &impl Fn(i32) -> F) { + f(1).await; +} + // EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.mir // EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.mir +// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.mir pub fn main() { block_on(async { let b = 2i32; @@ -40,5 +45,10 @@ pub fn main() { }; call_mut(&mut async_closure).await; call_once(async_closure).await; + + let async_closure = async move |a: i32| { + let a = &a; + }; + call_normal(&async_closure).await; }); }