diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs index 5bd7cc9514ca2..61b6bef3b87b9 100644 --- a/compiler/rustc_borrowck/src/type_check/input_output.rs +++ b/compiler/rustc_borrowck/src/type_check/input_output.rs @@ -94,31 +94,22 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { ); } - debug!( - "equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}", - body.yield_ty(), - universal_regions.yield_ty - ); - - // We will not have a universal_regions.yield_ty if we yield (by accident) - // outside of a coroutine and return an `impl Trait`, so emit a span_delayed_bug - // because we don't want to panic in an assert here if we've already got errors. - if body.yield_ty().is_some() != universal_regions.yield_ty.is_some() { - self.tcx().dcx().span_delayed_bug( - body.span, - format!( - "Expected body to have yield_ty ({:?}) iff we have a UR yield_ty ({:?})", - body.yield_ty(), - universal_regions.yield_ty, - ), + if let Some(mir_yield_ty) = body.yield_ty() { + let yield_span = body.local_decls[RETURN_PLACE].source_info.span; + self.equate_normalized_input_or_output( + universal_regions.yield_ty.unwrap(), + mir_yield_ty, + yield_span, ); } - if let (Some(mir_yield_ty), Some(ur_yield_ty)) = - (body.yield_ty(), universal_regions.yield_ty) - { + if let Some(mir_resume_ty) = body.resume_ty() { let yield_span = body.local_decls[RETURN_PLACE].source_info.span; - self.equate_normalized_input_or_output(ur_yield_ty, mir_yield_ty, yield_span); + self.equate_normalized_input_or_output( + universal_regions.resume_ty.unwrap(), + mir_resume_ty, + yield_span, + ); } // Return types are a bit more complex. They may contain opaque `impl Trait` types. diff --git a/compiler/rustc_borrowck/src/type_check/liveness/mod.rs b/compiler/rustc_borrowck/src/type_check/liveness/mod.rs index dc4695fd2b058..e137bc1be0aeb 100644 --- a/compiler/rustc_borrowck/src/type_check/liveness/mod.rs +++ b/compiler/rustc_borrowck/src/type_check/liveness/mod.rs @@ -183,6 +183,7 @@ impl<'cx, 'tcx> Visitor<'tcx> for LiveVariablesVisitor<'cx, 'tcx> { match ty_context { TyContext::ReturnTy(SourceInfo { span, .. }) | TyContext::YieldTy(SourceInfo { span, .. }) + | TyContext::ResumeTy(SourceInfo { span, .. }) | TyContext::UserTy(span) | TyContext::LocalDecl { source_info: SourceInfo { span, .. }, .. } => { span_bug!(span, "should not be visiting outside of the CFG: {:?}", ty_context); diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs index 80575e30a8d23..9c0f53ddb86fa 100644 --- a/compiler/rustc_borrowck/src/type_check/mod.rs +++ b/compiler/rustc_borrowck/src/type_check/mod.rs @@ -1450,13 +1450,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { } } } - TerminatorKind::Yield { value, .. } => { + TerminatorKind::Yield { value, resume_arg, .. } => { self.check_operand(value, term_location); - let value_ty = value.ty(body, tcx); match body.yield_ty() { None => span_mirbug!(self, term, "yield in non-coroutine"), Some(ty) => { + let value_ty = value.ty(body, tcx); if let Err(terr) = self.sub_types( value_ty, ty, @@ -1474,6 +1474,28 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { } } } + + match body.resume_ty() { + None => span_mirbug!(self, term, "yield in non-coroutine"), + Some(ty) => { + let resume_ty = resume_arg.ty(body, tcx); + if let Err(terr) = self.sub_types( + ty, + resume_ty.ty, + term_location.to_locations(), + ConstraintCategory::Yield, + ) { + span_mirbug!( + self, + term, + "type of resume place is {:?}, but the resume type is {:?}: {:?}", + resume_ty, + ty, + terr + ); + } + } + } } } } diff --git a/compiler/rustc_borrowck/src/universal_regions.rs b/compiler/rustc_borrowck/src/universal_regions.rs index a02304a2f8b30..addb41ff5fc8f 100644 --- a/compiler/rustc_borrowck/src/universal_regions.rs +++ b/compiler/rustc_borrowck/src/universal_regions.rs @@ -76,6 +76,8 @@ pub struct UniversalRegions<'tcx> { pub unnormalized_input_tys: &'tcx [Ty<'tcx>], pub yield_ty: Option>, + + pub resume_ty: Option>, } /// The "defining type" for this MIR. The key feature of the "defining @@ -525,9 +527,12 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> { debug!("build: extern regions = {}..{}", first_extern_index, first_local_index); debug!("build: local regions = {}..{}", first_local_index, num_universals); - let yield_ty = match defining_ty { - DefiningTy::Coroutine(_, args) => Some(args.as_coroutine().yield_ty()), - _ => None, + let (resume_ty, yield_ty) = match defining_ty { + DefiningTy::Coroutine(_, args) => { + let tys = args.as_coroutine(); + (Some(tys.resume_ty()), Some(tys.yield_ty())) + } + _ => (None, None), }; UniversalRegions { @@ -541,6 +546,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> { unnormalized_output_ty: *unnormalized_output_ty, unnormalized_input_tys, yield_ty, + resume_ty, } } diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 5c425fef27ebc..01ad3aefffa3b 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -250,6 +250,9 @@ pub struct CoroutineInfo<'tcx> { /// The yield type of the function, if it is a coroutine. pub yield_ty: Option>, + /// The resume type of the function, if it is a coroutine. + pub resume_ty: Option>, + /// Coroutine drop glue. pub coroutine_drop: Option>, @@ -385,6 +388,7 @@ impl<'tcx> Body<'tcx> { coroutine: coroutine_kind.map(|coroutine_kind| { Box::new(CoroutineInfo { yield_ty: None, + resume_ty: None, coroutine_drop: None, coroutine_layout: None, coroutine_kind, @@ -551,6 +555,11 @@ impl<'tcx> Body<'tcx> { self.coroutine.as_ref().and_then(|coroutine| coroutine.yield_ty) } + #[inline] + pub fn resume_ty(&self) -> Option> { + self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty) + } + #[inline] pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> { self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref()) diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 132ecf91af187..2ccf5a9f6f7ad 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -996,6 +996,12 @@ macro_rules! super_body { TyContext::YieldTy(SourceInfo::outermost(span)) ); } + if let Some(resume_ty) = $(& $mutability)? gen.resume_ty { + $self.visit_ty( + resume_ty, + TyContext::ResumeTy(SourceInfo::outermost(span)) + ); + } } for (bb, data) in basic_blocks_iter!($body, $($mutability, $invalidate)?) { @@ -1244,6 +1250,8 @@ pub enum TyContext { YieldTy(SourceInfo), + ResumeTy(SourceInfo), + /// A type found at some location. Location(Location), } diff --git a/compiler/rustc_mir_build/src/build/mod.rs b/compiler/rustc_mir_build/src/build/mod.rs index e0199fb876717..c4cade839478c 100644 --- a/compiler/rustc_mir_build/src/build/mod.rs +++ b/compiler/rustc_mir_build/src/build/mod.rs @@ -488,7 +488,7 @@ fn construct_fn<'tcx>( let arguments = &thir.params; - let (yield_ty, return_ty) = if coroutine_kind.is_some() { + let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() { let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty; let coroutine_sig = match coroutine_ty.kind() { ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(), @@ -496,9 +496,9 @@ fn construct_fn<'tcx>( span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty) } }; - (Some(coroutine_sig.yield_ty), coroutine_sig.return_ty) + (Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty) } else { - (None, fn_sig.output()) + (None, None, fn_sig.output()) }; if let Some(custom_mir_attr) = @@ -562,9 +562,12 @@ fn construct_fn<'tcx>( } else { None }; - if yield_ty.is_some() { + + if coroutine_kind.is_some() { body.coroutine.as_mut().unwrap().yield_ty = yield_ty; + body.coroutine.as_mut().unwrap().resume_ty = resume_ty; } + body } @@ -631,18 +634,18 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) - let hir_id = tcx.local_def_id_to_hir_id(def_id); let coroutine_kind = tcx.coroutine_kind(def_id); - let (inputs, output, yield_ty) = match tcx.def_kind(def_id) { + let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) { DefKind::Const | DefKind::AssocConst | DefKind::AnonConst | DefKind::InlineConst - | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None), + | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None), DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => { let sig = tcx.liberate_late_bound_regions( def_id.to_def_id(), tcx.fn_sig(def_id).instantiate_identity(), ); - (sig.inputs().to_vec(), sig.output(), None) + (sig.inputs().to_vec(), sig.output(), None, None) } DefKind::Closure if coroutine_kind.is_some() => { let coroutine_ty = tcx.type_of(def_id).instantiate_identity(); @@ -650,9 +653,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) - bug!("expected type of coroutine-like closure to be a coroutine") }; let args = args.as_coroutine(); + let resume_ty = args.resume_ty(); let yield_ty = args.yield_ty(); let return_ty = args.return_ty(); - (vec![coroutine_ty, args.resume_ty()], return_ty, Some(yield_ty)) + (vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty)) } DefKind::Closure => { let closure_ty = tcx.type_of(def_id).instantiate_identity(); @@ -666,7 +670,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) - ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty), ty::ClosureKind::FnOnce => closure_ty, }; - ([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None) + ([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None) } dk => bug!("{:?} is not a body: {:?}", def_id, dk), }; @@ -705,7 +709,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) - Some(guar), ); - body.coroutine.as_mut().map(|gen| gen.yield_ty = yield_ty); + body.coroutine.as_mut().map(|gen| { + gen.yield_ty = yield_ty; + gen.resume_ty = resume_ty; + }); body } diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index ce1a36cf67021..33e305497b505 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -1733,6 +1733,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } body.coroutine.as_mut().unwrap().yield_ty = None; + body.coroutine.as_mut().unwrap().resume_ty = None; body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout); // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes-2.rs b/tests/ui/coroutine/check-resume-ty-lifetimes-2.rs new file mode 100644 index 0000000000000..a316c50e86732 --- /dev/null +++ b/tests/ui/coroutine/check-resume-ty-lifetimes-2.rs @@ -0,0 +1,35 @@ +#![feature(coroutine_trait)] +#![feature(coroutines)] + +use std::ops::Coroutine; + +struct Contravariant<'a>(fn(&'a ())); +struct Covariant<'a>(fn() -> &'a ()); + +fn bad1<'short, 'long: 'short>() -> impl Coroutine> { + |_: Covariant<'short>| { + let a: Covariant<'long> = yield (); + //~^ ERROR lifetime may not live long enough + } +} + +fn bad2<'short, 'long: 'short>() -> impl Coroutine> { + |_: Contravariant<'long>| { + let a: Contravariant<'short> = yield (); + //~^ ERROR lifetime may not live long enough + } +} + +fn good1<'short, 'long: 'short>() -> impl Coroutine> { + |_: Covariant<'long>| { + let a: Covariant<'short> = yield (); + } +} + +fn good2<'short, 'long: 'short>() -> impl Coroutine> { + |_: Contravariant<'short>| { + let a: Contravariant<'long> = yield (); + } +} + +fn main() {} diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr b/tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr new file mode 100644 index 0000000000000..e0cbca2dd5267 --- /dev/null +++ b/tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr @@ -0,0 +1,36 @@ +error: lifetime may not live long enough + --> $DIR/check-resume-ty-lifetimes-2.rs:11:16 + | +LL | fn bad1<'short, 'long: 'short>() -> impl Coroutine> { + | ------ ----- lifetime `'long` defined here + | | + | lifetime `'short` defined here +LL | |_: Covariant<'short>| { +LL | let a: Covariant<'long> = yield (); + | ^^^^^^^^^^^^^^^^ type annotation requires that `'short` must outlive `'long` + | + = help: consider adding the following bound: `'short: 'long` +help: consider adding 'move' keyword before the nested closure + | +LL | move |_: Covariant<'short>| { + | ++++ + +error: lifetime may not live long enough + --> $DIR/check-resume-ty-lifetimes-2.rs:18:40 + | +LL | fn bad2<'short, 'long: 'short>() -> impl Coroutine> { + | ------ ----- lifetime `'long` defined here + | | + | lifetime `'short` defined here +LL | |_: Contravariant<'long>| { +LL | let a: Contravariant<'short> = yield (); + | ^^^^^^^^ yielding this value requires that `'short` must outlive `'long` + | + = help: consider adding the following bound: `'short: 'long` +help: consider adding 'move' keyword before the nested closure + | +LL | move |_: Contravariant<'long>| { + | ++++ + +error: aborting due to 2 previous errors + diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes.rs b/tests/ui/coroutine/check-resume-ty-lifetimes.rs new file mode 100644 index 0000000000000..add0b5080a8a8 --- /dev/null +++ b/tests/ui/coroutine/check-resume-ty-lifetimes.rs @@ -0,0 +1,27 @@ +#![feature(coroutine_trait)] +#![feature(coroutines)] +#![allow(unused)] + +use std::ops::Coroutine; +use std::ops::CoroutineState; +use std::pin::pin; + +fn mk_static(s: &str) -> &'static str { + let mut storage: Option<&'static str> = None; + + let mut coroutine = pin!(|_: &str| { + let x: &'static str = yield (); + //~^ ERROR lifetime may not live long enough + storage = Some(x); + }); + + coroutine.as_mut().resume(s); + coroutine.as_mut().resume(s); + + storage.unwrap() +} + +fn main() { + let s = mk_static(&String::from("hello, world")); + println!("{s}"); +} diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes.stderr b/tests/ui/coroutine/check-resume-ty-lifetimes.stderr new file mode 100644 index 0000000000000..f373aa778a82c --- /dev/null +++ b/tests/ui/coroutine/check-resume-ty-lifetimes.stderr @@ -0,0 +1,11 @@ +error: lifetime may not live long enough + --> $DIR/check-resume-ty-lifetimes.rs:13:16 + | +LL | fn mk_static(s: &str) -> &'static str { + | - let's call the lifetime of this reference `'1` +... +LL | let x: &'static str = yield (); + | ^^^^^^^^^^^^ type annotation requires that `'1` must outlive `'static` + +error: aborting due to 1 previous error +