Skip to content

Commit

Permalink
WIP: Remove ResumeTy from async lowering
Browse files Browse the repository at this point in the history
Instead of using the stdlib supported `ResumeTy`, which is being converting to a `&mut Context<'_>` during the Generator MIR pass,
this will use `&mut Context<'_>` directly in HIR lowering.

It pretty much reverts rust-lang#105977 and re-applies an updated version of rust-lang#105250.

This still fails the testcase added in rust-lang#106264 however, for reasons I don’t understand.
  • Loading branch information
Swatinem committed Dec 28, 2023
1 parent df5d535 commit 90ef563
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 218 deletions.
89 changes: 48 additions & 41 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,17 +632,27 @@ impl<'hir> LoweringContext<'_, 'hir> {
) -> hir::ExprKind<'hir> {
let output = ret_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span)));

// Resume argument type: `ResumeTy`
let unstable_span = self.mark_span_with_reason(
DesugaringKind::Async,
self.lower_span(span),
Some(self.allow_gen_future.clone()),
);
let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span);
// Resume argument type: `&mut Context<'_>`.
let context_lifetime_ident = Ident::with_dummy_span(kw::UnderscoreLifetime);
let context_lifetime = self.arena.alloc(hir::Lifetime {
hir_id: self.next_id(),
ident: context_lifetime_ident,
res: hir::LifetimeName::Infer,
});
let context_path = hir::QPath::LangItem(hir::LangItem::Context, self.lower_span(span));
let context_ty = hir::MutTy {
ty: self.arena.alloc(hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(context_path),
span: self.lower_span(span),
}),
mutbl: hir::Mutability::Mut,
};

let input_ty = hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(resume_ty),
span: unstable_span,
kind: hir::TyKind::Ref(context_lifetime, context_ty),
span: self.lower_span(span),
};

// The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
Expand Down Expand Up @@ -768,17 +778,27 @@ impl<'hir> LoweringContext<'_, 'hir> {
) -> hir::ExprKind<'hir> {
let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));

// Resume argument type: `ResumeTy`
let unstable_span = self.mark_span_with_reason(
DesugaringKind::Async,
self.lower_span(span),
Some(self.allow_gen_future.clone()),
);
let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span);
// Resume argument type: `&mut Context<'_>`.
let context_lifetime_ident = Ident::with_dummy_span(kw::UnderscoreLifetime);
let context_lifetime = self.arena.alloc(hir::Lifetime {
hir_id: self.next_id(),
ident: context_lifetime_ident,
res: hir::LifetimeName::Infer,
});
let context_path = hir::QPath::LangItem(hir::LangItem::Context, self.lower_span(span));
let context_ty = hir::MutTy {
ty: self.arena.alloc(hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(context_path),
span: self.lower_span(span),
}),
mutbl: hir::Mutability::Mut,
};

let input_ty = hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(resume_ty),
span: unstable_span,
kind: hir::TyKind::Ref(context_lifetime, context_ty),
span: self.lower_span(span),
};

// The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
Expand Down Expand Up @@ -871,7 +891,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// mut __awaitee => loop {
/// match unsafe { ::std::future::Future::poll(
/// <::std::pin::Pin>::new_unchecked(&mut __awaitee),
/// ::std::future::get_context(task_context),
/// task_context,
/// ) } {
/// ::std::task::Poll::Ready(result) => break result,
/// ::std::task::Poll::Pending => {}
Expand Down Expand Up @@ -912,29 +932,21 @@ impl<'hir> LoweringContext<'_, 'hir> {
FutureKind::AsyncIterator => Some(self.allow_for_await.clone()),
};
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features);
let gen_future_span = self.mark_span_with_reason(
DesugaringKind::Await,
full_span,
Some(self.allow_gen_future.clone()),
);
let expr_hir_id = expr.hir_id;

// Note that the name of this binding must not be changed to something else because
// debuggers and debugger extensions expect it to be called `__awaitee`. They use
// this name to identify what is being awaited by a suspended async functions.
let awaitee_ident = Ident::with_dummy_span(sym::__awaitee);
let (awaitee_pat, awaitee_pat_hid) = self.pat_ident_binding_mode(
gen_future_span,
awaitee_ident,
hir::BindingAnnotation::MUT,
);
let (awaitee_pat, awaitee_pat_hid) =
self.pat_ident_binding_mode(full_span, awaitee_ident, hir::BindingAnnotation::MUT);

let task_context_ident = Ident::with_dummy_span(sym::_task_context);

// unsafe {
// ::std::future::Future::poll(
// ::std::pin::Pin::new_unchecked(&mut __awaitee),
// ::std::future::get_context(task_context),
// task_context,
// )
// }
let poll_expr = {
Expand All @@ -952,21 +964,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::LangItem::PinNewUnchecked,
arena_vec![self; ref_mut_awaitee],
);
let get_context = self.expr_call_lang_item_fn_mut(
gen_future_span,
hir::LangItem::GetContext,
arena_vec![self; task_context],
);
let call = match await_kind {
FutureKind::Future => self.expr_call_lang_item_fn(
span,
hir::LangItem::FuturePoll,
arena_vec![self; new_unchecked, get_context],
arena_vec![self; new_unchecked, task_context],
),
FutureKind::AsyncIterator => self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncIteratorPollNext,
arena_vec![self; new_unchecked, get_context],
arena_vec![self; new_unchecked, task_context],
),
};
self.arena.alloc(self.expr_unsafe(call))
Expand All @@ -977,14 +984,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
let loop_hir_id = self.lower_node_id(loop_node_id);
let ready_arm = {
let x_ident = Ident::with_dummy_span(sym::result);
let (x_pat, x_pat_hid) = self.pat_ident(gen_future_span, x_ident);
let x_expr = self.expr_ident(gen_future_span, x_ident, x_pat_hid);
let ready_field = self.single_pat_field(gen_future_span, x_pat);
let (x_pat, x_pat_hid) = self.pat_ident(full_span, x_ident);
let x_expr = self.expr_ident(full_span, x_ident, x_pat_hid);
let ready_field = self.single_pat_field(full_span, x_pat);
let ready_pat = self.pat_lang_item_variant(span, hir::LangItem::PollReady, ready_field);
let break_x = self.with_loop_scope(loop_node_id, move |this| {
let expr_break =
hir::ExprKind::Break(this.lower_loop_destination(None), Some(x_expr));
this.arena.alloc(this.expr(gen_future_span, expr_break))
this.arena.alloc(this.expr(full_span, expr_break))
});
self.arm(ready_pat, break_x)
};
Expand Down
5 changes: 0 additions & 5 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,6 @@ language_item_table! {
AsyncGenPending, sym::AsyncGenPending, async_gen_pending, Target::AssocConst, GenericRequirement::Exact(1);
AsyncGenFinished, sym::AsyncGenFinished, async_gen_finished, Target::AssocConst, GenericRequirement::Exact(1);

// FIXME(swatinem): the following lang items are used for async lowering and
// should become obsolete eventually.
ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None;
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;

Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;

Expand Down
9 changes: 0 additions & 9 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2245,15 +2245,6 @@ impl<'tcx> Ty<'tcx> {
let def_id = tcx.require_lang_item(LangItem::MaybeUninit, None);
Ty::new_generic_adt(tcx, def_id, ty)
}

/// Creates a `&mut Context<'_>` [`Ty`] with erased lifetimes.
pub fn new_task_context(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let context_did = tcx.require_lang_item(LangItem::Context, None);
let context_adt_ref = tcx.adt_def(context_did);
let context_args = tcx.mk_args(&[tcx.lifetimes.re_erased.into()]);
let context_ty = Ty::new_adt(tcx, context_adt_ref, context_args);
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, context_ty)
}
}

/// Type utilities
Expand Down
121 changes: 3 additions & 118 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,112 +616,14 @@ fn replace_local<'tcx>(
new_local
}

/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
/// See <https://github.com/rust-lang/rust/issues/105501>.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = Ty::new_task_context(tcx);

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);

for bb in START_BLOCK..body.basic_blocks.next_index() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind() {
if def_id == get_context_def_id {
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
} else {
continue;
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
}
}
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
let arg = args.pop().unwrap();
let local = arg.place().unwrap().local;

let arg = Rvalue::Use(arg);
let assign = Statement {
source_info: terminator.source_info,
kind: StatementKind::Assign(Box::new((destination, arg))),
};
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
} else {
bug!();
}
}

#[cfg_attr(not(debug_assertions), allow(unused))]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
}

/// Transforms the `body` of the coroutine applying the following transform:
///
/// - Remove the `resume` argument.
///
/// Ideally the async lowering would not add the `resume` argument.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `resume` argument for the time being. After this transform,
/// still using the `resume` argument for the time being. After this transform
/// the coroutine body doesn't have the `resume` argument.
fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// This leaves the local representing the `resume` argument in place,
Expand Down Expand Up @@ -1613,14 +1515,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
};

let is_async_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
);
let is_async_gen_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
);
let is_gen_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _))
Expand Down Expand Up @@ -1657,22 +1551,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if is_async_kind || is_async_gen_kind {
transform_async_context(tcx, body);
}

// We also replace the resume argument and insert an `Assign`.
// This is needed because the resume argument `_2` might be live across a `yield`, in which
// case there is no `Assign` to it that the transform can turn into a store to the coroutine
// state. After the yield the slot in the coroutine state would then be uninitialized.
let resume_local = Local::new(2);
let resume_ty = if is_async_kind {
Ty::new_task_context(tcx)
} else {
body.local_decls[resume_local].ty
};
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
let old_resume_local =
replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);

// When first entering the coroutine, move the resume argument into its old local
// (which is now a generator interior).
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ symbols! {
Relaxed,
Release,
Result,
ResumeTy,
Return,
Right,
Rust,
Expand Down Expand Up @@ -833,7 +832,6 @@ symbols! {
generic_const_exprs,
generic_const_items,
generic_param_attrs,
get_context,
global_allocator,
global_asm,
globs,
Expand Down
Loading

0 comments on commit 90ef563

Please sign in to comment.