Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some more coroutine renamings #117419

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
closure_node_id: NodeId,
ret_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
async_gen_kind: hir::CoroutineSource,
async_coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = ret_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span)));
Expand Down Expand Up @@ -645,7 +645,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let params = arena_vec![self; param];

let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::Async(async_gen_kind));
this.coroutine_kind = Some(hir::CoroutineKind::Async(async_coroutine_source));

let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
Expand Down Expand Up @@ -684,7 +684,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
closure_node_id: NodeId,
_yield_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
gen_kind: hir::CoroutineSource,
coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));
Expand All @@ -699,7 +699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
});

let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::Gen(gen_kind));
this.coroutine_kind = Some(hir::CoroutineKind::Gen(coroutine_source));

let res = body(this);
(&[], res)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,12 +988,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
&mut self,
f: impl FnOnce(&mut Self) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>),
) -> hir::BodyId {
let prev_gen_kind = self.coroutine_kind.take();
let prev_coroutine_kind = self.coroutine_kind.take();
let task_context = self.task_context.take();
let (parameters, result) = f(self);
let body_id = self.record_body(parameters, result);
self.task_context = task_context;
self.coroutine_kind = prev_gen_kind;
self.coroutine_kind = prev_coroutine_kind;
body_id
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ language_item_table! {

Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0);
Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0);
CoroutineState, sym::coroutine_state, gen_state, Target::Enum, GenericRequirement::None;
Coroutine, sym::coroutine, gen_trait, Target::Trait, GenericRequirement::Minimum(1);
CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None;
Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1);
Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None;
Pin, sym::pin, pin_type, Target::Struct, GenericRequirement::None;

Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_hir_typeck/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ pub(super) fn check_fn<'a, 'tcx>(
// We insert the deferred_coroutine_interiors entry after visiting the body.
// This ensures that all nested coroutines appear before the entry of this coroutine.
// resolve_coroutine_interiors relies on this property.
let gen_ty = if let (Some(_), Some(gen_kind)) = (can_be_coroutine, body.coroutine_kind) {
let coroutine_ty = if let (Some(_), Some(coroutine_kind)) =
(can_be_coroutine, body.coroutine_kind)
{
let interior = fcx
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::MiscVariable, span });
fcx.deferred_coroutine_interiors.borrow_mut().push((
fn_def_id,
body.id(),
interior,
gen_kind,
coroutine_kind,
));

let (resume_ty, yield_ty) = fcx.resume_yield_tys.unwrap();
Expand Down Expand Up @@ -184,7 +186,7 @@ pub(super) fn check_fn<'a, 'tcx>(
check_lang_start_fn(tcx, fn_sig, fn_def_id);
}

gen_ty
coroutine_ty
}

fn check_panic_info_fn(tcx: TyCtxt<'_>, fn_id: LocalDefId, fn_sig: ty::FnSig<'_>) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

let is_fn = tcx.is_fn_trait(trait_def_id);

let gen_trait = tcx.lang_items().gen_trait();
let is_gen = gen_trait == Some(trait_def_id);
let coroutine_trait = tcx.lang_items().coroutine_trait();
let is_gen = coroutine_trait == Some(trait_def_id);

if !is_fn && !is_gen {
debug!("not fn or coroutine");
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let term = if let Some(ty) = term.skip_binder().ty()
&& let ty::Alias(ty::Projection, proj) = ty.kind()
&& let Some(assoc) = tcx.opt_associated_item(proj.def_id)
&& assoc.trait_container(tcx) == tcx.lang_items().gen_trait()
&& assoc.trait_container(tcx) == tcx.lang_items().coroutine_trait()
&& assoc.name == rustc_span::sym::Return
{
if let ty::Coroutine(_, args, _) = args.type_at(0).kind() {
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,14 @@ fn construct_fn<'tcx>(
let arguments = &thir.params;

let (yield_ty, return_ty) = if coroutine_kind.is_some() {
let gen_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let gen_sig = match gen_ty.kind() {
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let coroutine_sig = match coroutine_ty.kind() {
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
_ => {
span_bug!(span, "coroutine w/o coroutine type: {:?}", gen_ty)
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
}
};
(Some(gen_sig.yield_ty), gen_sig.return_ty)
(Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
} else {
(None, fn_sig.output())
};
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_mir_build/src/thir/cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,15 @@ impl<'tcx> Cx<'tcx> {
Some(env_param)
}
DefKind::Coroutine => {
let gen_ty = self.typeck_results.node_type(owner_id);
let gen_param =
Param { ty: gen_ty, pat: None, ty_span: None, self_kind: None, hir_id: None };
Some(gen_param)
let coroutine_ty = self.typeck_results.node_type(owner_id);
let coroutine_param = Param {
ty: coroutine_ty,
pat: None,
ty_span: None,
self_kind: None,
hir_id: None,
};
Some(coroutine_param)
}
_ => None,
}
Expand Down
47 changes: 24 additions & 23 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
}

struct PinArgVisitor<'tcx> {
ref_gen_ty: Ty<'tcx>,
ref_coroutine_ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
}

Expand All @@ -168,7 +168,7 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
local: SELF_ARG,
projection: self.tcx().mk_place_elems(&[ProjectionElem::Field(
FieldIdx::new(0),
self.ref_gen_ty,
self.ref_coroutine_ty,
)]),
},
self.tcx,
Expand Down Expand Up @@ -468,34 +468,34 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
}

fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let gen_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls.raw[1].ty;

let ref_gen_ty = Ty::new_ref(
let ref_coroutine_ty = Ty::new_ref(
tcx,
tcx.lifetimes.re_erased,
ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut },
ty::TypeAndMut { ty: coroutine_ty, mutbl: Mutability::Mut },
);

// Replace the by value coroutine argument
body.local_decls.raw[1].ty = ref_gen_ty;
body.local_decls.raw[1].ty = ref_coroutine_ty;

// Add a deref to accesses of the coroutine state
DerefArgVisitor { tcx }.visit_body(body);
}

fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let ref_gen_ty = body.local_decls.raw[1].ty;
let ref_coroutine_ty = body.local_decls.raw[1].ty;

let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
let pin_adt_ref = tcx.adt_def(pin_did);
let args = tcx.mk_args(&[ref_gen_ty.into()]);
let pin_ref_gen_ty = Ty::new_adt(tcx, pin_adt_ref, args);
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);

// Replace the by ref coroutine argument
body.local_decls.raw[1].ty = pin_ref_gen_ty;
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;

// Add the Pin field access to accesses of the coroutine state
PinArgVisitor { ref_gen_ty, tcx }.visit_body(body);
PinArgVisitor { ref_coroutine_ty, tcx }.visit_body(body);
}

/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
Expand Down Expand Up @@ -1104,7 +1104,7 @@ fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
fn create_coroutine_drop_shim<'tcx>(
tcx: TyCtxt<'tcx>,
transform: &TransformVisitor<'tcx>,
gen_ty: Ty<'tcx>,
coroutine_ty: Ty<'tcx>,
body: &mut Body<'tcx>,
drop_clean: BasicBlock,
) -> Body<'tcx> {
Expand Down Expand Up @@ -1136,7 +1136,7 @@ fn create_coroutine_drop_shim<'tcx>(

// Change the coroutine argument from &mut to *mut
body.local_decls[SELF_ARG] = LocalDecl::with_source_info(
Ty::new_ptr(tcx, ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }),
Ty::new_ptr(tcx, ty::TypeAndMut { ty: coroutine_ty, mutbl: hir::Mutability::Mut }),
source_info,
);

Expand All @@ -1146,9 +1146,9 @@ fn create_coroutine_drop_shim<'tcx>(

// Update the body's def to become the drop glue.
// This needs to be updated before the AbortUnwindingCalls pass.
let gen_instance = body.source.instance;
let coroutine_instance = body.source.instance;
let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, None);
let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(gen_ty));
let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(coroutine_ty));
body.source.instance = drop_instance;

pm::run_passes_no_validate(
Expand All @@ -1160,7 +1160,7 @@ fn create_coroutine_drop_shim<'tcx>(

// Temporary change MirSource to coroutine's instance so that dump_mir produces more sensible
// filename.
body.source.instance = gen_instance;
body.source.instance = coroutine_instance;
dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
body.source.instance = drop_instance;

Expand Down Expand Up @@ -1447,13 +1447,13 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
let body = &*body;

// The first argument is the coroutine type passed by value
let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;

// Get the interior types and args which typeck computed
let movable = match *gen_ty.kind() {
let movable = match *coroutine_ty.kind() {
ty::Coroutine(_, _, movability) => movability == hir::Movability::Movable,
ty::Error(_) => return None,
_ => span_bug!(body.span, "unexpected coroutine type {}", gen_ty),
_ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty),
};

// When first entering the coroutine, move the resume argument into its new local.
Expand Down Expand Up @@ -1481,16 +1481,17 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
assert!(body.coroutine_drop().is_none());

// The first argument is the coroutine type passed by value
let gen_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls.raw[1].ty;

// Get the discriminant type and args which typeck computed
let (discr_ty, movable) = match *gen_ty.kind() {
let (discr_ty, movable) = match *coroutine_ty.kind() {
ty::Coroutine(_, args, movability) => {
let args = args.as_coroutine();
(args.discr_ty(tcx), movability == hir::Movability::Movable)
}
_ => {
tcx.sess.delay_span_bug(body.span, format!("unexpected coroutine type {gen_ty}"));
tcx.sess
.delay_span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
return;
}
};
Expand Down Expand Up @@ -1626,7 +1627,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
dump_mir(tcx, false, "coroutine_post-transform", &0, body, |_, _| Ok(()));

// Create a copy of our MIR and use it to create the drop shim for the coroutine
let drop_shim = create_coroutine_drop_shim(tcx, &transform, gen_ty, body, drop_clean);
let drop_shim = create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);

body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

Expand Down
14 changes: 7 additions & 7 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
ty::InstanceDef::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
// of this function. Is this intentional?
if let Some(ty::Coroutine(gen_def_id, args, _)) = ty.map(Ty::kind) {
let body = tcx.optimized_mir(*gen_def_id).coroutine_drop().unwrap();
if let Some(ty::Coroutine(coroutine_def_id, args, _)) = ty.map(Ty::kind) {
let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap();
let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
debug!("make_shim({:?}) = {:?}", instance, body);

Expand Down Expand Up @@ -392,8 +392,8 @@ fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -
_ if is_copy => builder.copy_shim(),
ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()),
ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()),
ty::Coroutine(gen_def_id, args, hir::Movability::Movable) => {
builder.coroutine_shim(dest, src, *gen_def_id, args.as_coroutine())
ty::Coroutine(coroutine_def_id, args, hir::Movability::Movable) => {
builder.coroutine_shim(dest, src, *coroutine_def_id, args.as_coroutine())
}
_ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty),
};
Expand Down Expand Up @@ -597,7 +597,7 @@ impl<'tcx> CloneShimBuilder<'tcx> {
&mut self,
dest: Place<'tcx>,
src: Place<'tcx>,
gen_def_id: DefId,
coroutine_def_id: DefId,
args: CoroutineArgs<'tcx>,
) {
self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false);
Expand All @@ -607,8 +607,8 @@ impl<'tcx> CloneShimBuilder<'tcx> {
let unwind = self.clone_fields(dest, src, switch, unwind, args.upvar_tys());
let target = self.block(vec![], TerminatorKind::Return, false);
let unreachable = self.block(vec![], TerminatorKind::Unreachable, false);
let mut cases = Vec::with_capacity(args.state_tys(gen_def_id, self.tcx).count());
for (index, state_tys) in args.state_tys(gen_def_id, self.tcx).enumerate() {
let mut cases = Vec::with_capacity(args.state_tys(coroutine_def_id, self.tcx).count());
for (index, state_tys) in args.state_tys(coroutine_def_id, self.tcx).enumerate() {
let variant_index = VariantIdx::new(index);
let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index);
let src = self.tcx.mk_place_downcast_unnamed(src, variant_index);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/solve/assembly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
G::consider_builtin_future_candidate(self, goal)
} else if lang_items.iterator_trait() == Some(trait_def_id) {
G::consider_builtin_iterator_candidate(self, goal)
} else if lang_items.gen_trait() == Some(trait_def_id) {
} else if lang_items.coroutine_trait() == Some(trait_def_id) {
G::consider_builtin_coroutine_candidate(self, goal)
} else if lang_items.discriminant_kind_trait() == Some(trait_def_id) {
G::consider_builtin_discriminant_kind_candidate(self, goal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}

fn describe_coroutine(&self, body_id: hir::BodyId) -> Option<&'static str> {
self.tcx.hir().body(body_id).coroutine_kind.map(|gen_kind| match gen_kind {
self.tcx.hir().body(body_id).coroutine_kind.map(|coroutine_source| match coroutine_source {
hir::CoroutineKind::Coroutine => "a coroutine",
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block",
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function",
Expand Down Expand Up @@ -3187,7 +3187,8 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
// traits manually, but don't make it more confusing when it does
// happen.
Some(
if Some(expected_trait_ref.def_id()) != self.tcx.lang_items().gen_trait() && not_tupled
if Some(expected_trait_ref.def_id()) != self.tcx.lang_items().coroutine_trait()
&& not_tupled
{
self.report_and_explain_type_error(
TypeTrace::poly_trait_refs(
Expand Down
Loading
Loading