Skip to content

Commit

Permalink
Some more coroutine renamings
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Oct 30, 2023
1 parent 31bc7e2 commit add09e6
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 76 deletions.
8 changes: 4 additions & 4 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
closure_node_id: NodeId,
ret_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
async_gen_kind: hir::CoroutineSource,
async_coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = ret_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span)));
Expand Down Expand Up @@ -645,7 +645,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let params = arena_vec![self; param];

let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::Async(async_gen_kind));
this.coroutine_kind = Some(hir::CoroutineKind::Async(async_coroutine_source));

let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
Expand Down Expand Up @@ -684,7 +684,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
closure_node_id: NodeId,
_yield_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
gen_kind: hir::CoroutineSource,
coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));
Expand All @@ -699,7 +699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
});

let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::Gen(gen_kind));
this.coroutine_kind = Some(hir::CoroutineKind::Gen(coroutine_source));

let res = body(this);
(&[], res)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,12 +988,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
&mut self,
f: impl FnOnce(&mut Self) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>),
) -> hir::BodyId {
let prev_gen_kind = self.coroutine_kind.take();
let prev_coroutine_kind = self.coroutine_kind.take();
let task_context = self.task_context.take();
let (parameters, result) = f(self);
let body_id = self.record_body(parameters, result);
self.task_context = task_context;
self.coroutine_kind = prev_gen_kind;
self.coroutine_kind = prev_coroutine_kind;
body_id
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ language_item_table! {

Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0);
Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0);
CoroutineState, sym::coroutine_state, gen_state, Target::Enum, GenericRequirement::None;
Coroutine, sym::coroutine, gen_trait, Target::Trait, GenericRequirement::Minimum(1);
CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None;
Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1);
Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None;
Pin, sym::pin, pin_type, Target::Struct, GenericRequirement::None;

Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_hir_typeck/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ pub(super) fn check_fn<'a, 'tcx>(
// We insert the deferred_coroutine_interiors entry after visiting the body.
// This ensures that all nested coroutines appear before the entry of this coroutine.
// resolve_coroutine_interiors relies on this property.
let gen_ty = if let (Some(_), Some(gen_kind)) = (can_be_coroutine, body.coroutine_kind) {
let coroutine_ty = if let (Some(_), Some(coroutine_kind)) =
(can_be_coroutine, body.coroutine_kind)
{
let interior = fcx
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::MiscVariable, span });
fcx.deferred_coroutine_interiors.borrow_mut().push((
fn_def_id,
body.id(),
interior,
gen_kind,
coroutine_kind,
));

let (resume_ty, yield_ty) = fcx.resume_yield_tys.unwrap();
Expand Down Expand Up @@ -184,7 +186,7 @@ pub(super) fn check_fn<'a, 'tcx>(
check_lang_start_fn(tcx, fn_sig, fn_def_id);
}

gen_ty
coroutine_ty
}

fn check_panic_info_fn(tcx: TyCtxt<'_>, fn_id: LocalDefId, fn_sig: ty::FnSig<'_>) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

let is_fn = tcx.is_fn_trait(trait_def_id);

let gen_trait = tcx.lang_items().gen_trait();
let is_gen = gen_trait == Some(trait_def_id);
let coroutine_trait = tcx.lang_items().coroutine_trait();
let is_gen = coroutine_trait == Some(trait_def_id);

if !is_fn && !is_gen {
debug!("not fn or coroutine");
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let term = if let Some(ty) = term.skip_binder().ty()
&& let ty::Alias(ty::Projection, proj) = ty.kind()
&& let Some(assoc) = tcx.opt_associated_item(proj.def_id)
&& assoc.trait_container(tcx) == tcx.lang_items().gen_trait()
&& assoc.trait_container(tcx) == tcx.lang_items().coroutine_trait()
&& assoc.name == rustc_span::sym::Return
{
if let ty::Coroutine(_, args, _) = args.type_at(0).kind() {
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,14 @@ fn construct_fn<'tcx>(
let arguments = &thir.params;

let (yield_ty, return_ty) = if coroutine_kind.is_some() {
let gen_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let gen_sig = match gen_ty.kind() {
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let coroutine_sig = match coroutine_ty.kind() {
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
_ => {
span_bug!(span, "coroutine w/o coroutine type: {:?}", gen_ty)
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
}
};
(Some(gen_sig.yield_ty), gen_sig.return_ty)
(Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
} else {
(None, fn_sig.output())
};
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_mir_build/src/thir/cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,15 @@ impl<'tcx> Cx<'tcx> {
Some(env_param)
}
DefKind::Coroutine => {
let gen_ty = self.typeck_results.node_type(owner_id);
let gen_param =
Param { ty: gen_ty, pat: None, ty_span: None, self_kind: None, hir_id: None };
Some(gen_param)
let coroutine_ty = self.typeck_results.node_type(owner_id);
let coroutine_param = Param {
ty: coroutine_ty,
pat: None,
ty_span: None,
self_kind: None,
hir_id: None,
};
Some(coroutine_param)
}
_ => None,
}
Expand Down
47 changes: 24 additions & 23 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
}

struct PinArgVisitor<'tcx> {
ref_gen_ty: Ty<'tcx>,
ref_coroutine_ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
}

Expand All @@ -168,7 +168,7 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
local: SELF_ARG,
projection: self.tcx().mk_place_elems(&[ProjectionElem::Field(
FieldIdx::new(0),
self.ref_gen_ty,
self.ref_coroutine_ty,
)]),
},
self.tcx,
Expand Down Expand Up @@ -468,34 +468,34 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
}

fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let gen_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls.raw[1].ty;

let ref_gen_ty = Ty::new_ref(
let ref_coroutine_ty = Ty::new_ref(
tcx,
tcx.lifetimes.re_erased,
ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut },
ty::TypeAndMut { ty: coroutine_ty, mutbl: Mutability::Mut },
);

// Replace the by value coroutine argument
body.local_decls.raw[1].ty = ref_gen_ty;
body.local_decls.raw[1].ty = ref_coroutine_ty;

// Add a deref to accesses of the coroutine state
DerefArgVisitor { tcx }.visit_body(body);
}

fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let ref_gen_ty = body.local_decls.raw[1].ty;
let ref_coroutine_ty = body.local_decls.raw[1].ty;

let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
let pin_adt_ref = tcx.adt_def(pin_did);
let args = tcx.mk_args(&[ref_gen_ty.into()]);
let pin_ref_gen_ty = Ty::new_adt(tcx, pin_adt_ref, args);
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);

// Replace the by ref coroutine argument
body.local_decls.raw[1].ty = pin_ref_gen_ty;
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;

// Add the Pin field access to accesses of the coroutine state
PinArgVisitor { ref_gen_ty, tcx }.visit_body(body);
PinArgVisitor { ref_coroutine_ty, tcx }.visit_body(body);
}

/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
Expand Down Expand Up @@ -1104,7 +1104,7 @@ fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
fn create_coroutine_drop_shim<'tcx>(
tcx: TyCtxt<'tcx>,
transform: &TransformVisitor<'tcx>,
gen_ty: Ty<'tcx>,
coroutine_ty: Ty<'tcx>,
body: &mut Body<'tcx>,
drop_clean: BasicBlock,
) -> Body<'tcx> {
Expand Down Expand Up @@ -1136,7 +1136,7 @@ fn create_coroutine_drop_shim<'tcx>(

// Change the coroutine argument from &mut to *mut
body.local_decls[SELF_ARG] = LocalDecl::with_source_info(
Ty::new_ptr(tcx, ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }),
Ty::new_ptr(tcx, ty::TypeAndMut { ty: coroutine_ty, mutbl: hir::Mutability::Mut }),
source_info,
);

Expand All @@ -1146,9 +1146,9 @@ fn create_coroutine_drop_shim<'tcx>(

// Update the body's def to become the drop glue.
// This needs to be updated before the AbortUnwindingCalls pass.
let gen_instance = body.source.instance;
let coroutine_instance = body.source.instance;
let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, None);
let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(gen_ty));
let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(coroutine_ty));
body.source.instance = drop_instance;

pm::run_passes_no_validate(
Expand All @@ -1160,7 +1160,7 @@ fn create_coroutine_drop_shim<'tcx>(

// Temporary change MirSource to coroutine's instance so that dump_mir produces more sensible
// filename.
body.source.instance = gen_instance;
body.source.instance = coroutine_instance;
dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
body.source.instance = drop_instance;

Expand Down Expand Up @@ -1447,13 +1447,13 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
let body = &*body;

// The first argument is the coroutine type passed by value
let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;

// Get the interior types and args which typeck computed
let movable = match *gen_ty.kind() {
let movable = match *coroutine_ty.kind() {
ty::Coroutine(_, _, movability) => movability == hir::Movability::Movable,
ty::Error(_) => return None,
_ => span_bug!(body.span, "unexpected coroutine type {}", gen_ty),
_ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty),
};

// When first entering the coroutine, move the resume argument into its new local.
Expand Down Expand Up @@ -1481,16 +1481,17 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
assert!(body.coroutine_drop().is_none());

// The first argument is the coroutine type passed by value
let gen_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls.raw[1].ty;

// Get the discriminant type and args which typeck computed
let (discr_ty, movable) = match *gen_ty.kind() {
let (discr_ty, movable) = match *coroutine_ty.kind() {
ty::Coroutine(_, args, movability) => {
let args = args.as_coroutine();
(args.discr_ty(tcx), movability == hir::Movability::Movable)
}
_ => {
tcx.sess.delay_span_bug(body.span, format!("unexpected coroutine type {gen_ty}"));
tcx.sess
.delay_span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
return;
}
};
Expand Down Expand Up @@ -1626,7 +1627,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
dump_mir(tcx, false, "coroutine_post-transform", &0, body, |_, _| Ok(()));

// Create a copy of our MIR and use it to create the drop shim for the coroutine
let drop_shim = create_coroutine_drop_shim(tcx, &transform, gen_ty, body, drop_clean);
let drop_shim = create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);

body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

Expand Down
14 changes: 7 additions & 7 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
ty::InstanceDef::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
// of this function. Is this intentional?
if let Some(ty::Coroutine(gen_def_id, args, _)) = ty.map(Ty::kind) {
let body = tcx.optimized_mir(*gen_def_id).coroutine_drop().unwrap();
if let Some(ty::Coroutine(coroutine_def_id, args, _)) = ty.map(Ty::kind) {
let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap();
let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
debug!("make_shim({:?}) = {:?}", instance, body);

Expand Down Expand Up @@ -392,8 +392,8 @@ fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -
_ if is_copy => builder.copy_shim(),
ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()),
ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()),
ty::Coroutine(gen_def_id, args, hir::Movability::Movable) => {
builder.coroutine_shim(dest, src, *gen_def_id, args.as_coroutine())
ty::Coroutine(coroutine_def_id, args, hir::Movability::Movable) => {
builder.coroutine_shim(dest, src, *coroutine_def_id, args.as_coroutine())
}
_ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty),
};
Expand Down Expand Up @@ -597,7 +597,7 @@ impl<'tcx> CloneShimBuilder<'tcx> {
&mut self,
dest: Place<'tcx>,
src: Place<'tcx>,
gen_def_id: DefId,
coroutine_def_id: DefId,
args: CoroutineArgs<'tcx>,
) {
self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false);
Expand All @@ -607,8 +607,8 @@ impl<'tcx> CloneShimBuilder<'tcx> {
let unwind = self.clone_fields(dest, src, switch, unwind, args.upvar_tys());
let target = self.block(vec![], TerminatorKind::Return, false);
let unreachable = self.block(vec![], TerminatorKind::Unreachable, false);
let mut cases = Vec::with_capacity(args.state_tys(gen_def_id, self.tcx).count());
for (index, state_tys) in args.state_tys(gen_def_id, self.tcx).enumerate() {
let mut cases = Vec::with_capacity(args.state_tys(coroutine_def_id, self.tcx).count());
for (index, state_tys) in args.state_tys(coroutine_def_id, self.tcx).enumerate() {
let variant_index = VariantIdx::new(index);
let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index);
let src = self.tcx.mk_place_downcast_unnamed(src, variant_index);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/solve/assembly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
G::consider_builtin_future_candidate(self, goal)
} else if lang_items.iterator_trait() == Some(trait_def_id) {
G::consider_builtin_iterator_candidate(self, goal)
} else if lang_items.gen_trait() == Some(trait_def_id) {
} else if lang_items.coroutine_trait() == Some(trait_def_id) {
G::consider_builtin_coroutine_candidate(self, goal)
} else if lang_items.discriminant_kind_trait() == Some(trait_def_id) {
G::consider_builtin_discriminant_kind_candidate(self, goal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}

fn describe_coroutine(&self, body_id: hir::BodyId) -> Option<&'static str> {
self.tcx.hir().body(body_id).coroutine_kind.map(|gen_kind| match gen_kind {
self.tcx.hir().body(body_id).coroutine_kind.map(|coroutine_source| match coroutine_source {
hir::CoroutineKind::Coroutine => "a coroutine",
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block",
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function",
Expand Down Expand Up @@ -3187,7 +3187,8 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
// traits manually, but don't make it more confusing when it does
// happen.
Some(
if Some(expected_trait_ref.def_id()) != self.tcx.lang_items().gen_trait() && not_tupled
if Some(expected_trait_ref.def_id()) != self.tcx.lang_items().coroutine_trait()
&& not_tupled
{
self.report_and_explain_type_error(
TypeTrace::poly_trait_refs(
Expand Down
Loading

0 comments on commit add09e6

Please sign in to comment.