From ae4c5c891ec55e2bf78615650daded9124552ca0 Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Fri, 22 Mar 2024 00:48:36 +0900 Subject: [PATCH] Implement `FusedIterator` for `gen` block --- compiler/rustc_hir/src/lang_items.rs | 1 + compiler/rustc_span/src/symbol.rs | 2 ++ .../src/solve/assembly/mod.rs | 9 +++++ .../src/solve/normalizes_to/mod.rs | 7 ++++ .../src/solve/trait_goals.rs | 22 +++++++++++++ .../src/traits/select/candidate_assembly.rs | 33 +++++++++++++++---- .../src/traits/select/confirmation.rs | 2 ++ .../src/traits/select/mod.rs | 14 ++++++++ library/core/src/iter/traits/marker.rs | 1 + tests/ui/coroutine/gen_block_is_fused_iter.rs | 21 ++++++++++++ 10 files changed, 105 insertions(+), 7 deletions(-) create mode 100644 tests/ui/coroutine/gen_block_is_fused_iter.rs diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 5118bf5c3b7ab..dbf86f5cf747d 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -214,6 +214,7 @@ language_item_table! { FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None; Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0); + FusedIterator, sym::fused_iterator, fused_iterator_trait, Target::Trait, GenericRequirement::Exact(0); Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0); AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0); diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 8b911a41a112f..c28c577d78014 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -207,6 +207,7 @@ symbols! { FromResidual, FsOpenOptions, FsPermissions, + FusedIterator, Future, FutureOutput, GlobalAlloc, @@ -885,6 +886,7 @@ symbols! { fsub_algebraic, fsub_fast, fundamental, + fused_iterator, future, future_trait, gdb_script_file, diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs index 9f33dce2a6dfe..d92bae2528fbc 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs @@ -215,6 +215,13 @@ pub(super) trait GoalKind<'tcx>: goal: Goal<'tcx, Self>, ) -> QueryResult<'tcx>; + /// A coroutine (that comes from a `gen` desugaring) is known to implement + /// `FusedIterator` + fn consider_builtin_fused_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx>; + fn consider_builtin_async_iterator_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, @@ -497,6 +504,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> { G::consider_builtin_future_candidate(self, goal) } else if lang_items.iterator_trait() == Some(trait_def_id) { G::consider_builtin_iterator_candidate(self, goal) + } else if lang_items.fused_iterator_trait() == Some(trait_def_id) { + G::consider_builtin_fused_iterator_candidate(self, goal) } else if lang_items.async_iterator_trait() == Some(trait_def_id) { G::consider_builtin_async_iterator_candidate(self, goal) } else if lang_items.coroutine_trait() == Some(trait_def_id) { diff --git a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs index 85bb6338daff9..6668889323512 100644 --- a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs @@ -647,6 +647,13 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> { ) } + fn consider_builtin_fused_iterator_candidate( + _ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + bug!("`FusedIterator` does not have an associated type: {:?}", goal); + } + fn consider_builtin_async_iterator_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index c252ad76dfe1d..184ba31f19d20 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -456,6 +456,28 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) } + fn consider_builtin_fused_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + if goal.predicate.polarity != ty::ImplPolarity::Positive { + return Err(NoSolution); + } + + let ty::Coroutine(def_id, _) = *goal.predicate.self_ty().kind() else { + return Err(NoSolution); + }; + + // Coroutines are not iterators unless they come from `gen` desugaring + let tcx = ecx.tcx(); + if !tcx.coroutine_is_gen(def_id) { + return Err(NoSolution); + } + + // Gen coroutines unconditionally implement `FusedIterator` + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + } + fn consider_builtin_async_iterator_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 49091e53be713..9fb4577fb2169 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -118,6 +118,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self.assemble_future_candidates(obligation, &mut candidates); } else if lang_items.iterator_trait() == Some(def_id) { self.assemble_iterator_candidates(obligation, &mut candidates); + } else if lang_items.fused_iterator_trait() == Some(def_id) { + self.assemble_fused_iterator_candidates(obligation, &mut candidates); } else if lang_items.async_iterator_trait() == Some(def_id) { self.assemble_async_iterator_candidates(obligation, &mut candidates); } else if lang_items.async_fn_kind_helper() == Some(def_id) { @@ -302,14 +304,31 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { candidates: &mut SelectionCandidateSet<'tcx>, ) { let self_ty = obligation.self_ty().skip_binder(); - if let ty::Coroutine(did, ..) = self_ty.kind() { - // gen constructs get lowered to a special kind of coroutine that - // should directly `impl Iterator`. - if self.tcx().coroutine_is_gen(*did) { - debug!(?self_ty, ?obligation, "assemble_iterator_candidates",); + // gen constructs get lowered to a special kind of coroutine that + // should directly `impl Iterator`. + if let ty::Coroutine(did, ..) = self_ty.kind() + && self.tcx().coroutine_is_gen(*did) + { + debug!(?self_ty, ?obligation, "assemble_iterator_candidates",); - candidates.vec.push(IteratorCandidate); - } + candidates.vec.push(IteratorCandidate); + } + } + + fn assemble_fused_iterator_candidates( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + candidates: &mut SelectionCandidateSet<'tcx>, + ) { + let self_ty = obligation.self_ty().skip_binder(); + // gen constructs get lowered to a special kind of coroutine that + // should directly `impl FusedIterator`. + if let ty::Coroutine(did, ..) = self_ty.kind() + && self.tcx().coroutine_is_gen(*did) + { + debug!(?self_ty, ?obligation, "assemble_fused_iterator_candidates",); + + candidates.vec.push(BuiltinCandidate { has_nested: false }); } } diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 51fc223a5d1b3..aeac2ad77d70f 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -267,6 +267,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self.copy_clone_conditions(obligation) } else if Some(trait_def) == lang_items.clone_trait() { self.copy_clone_conditions(obligation) + } else if Some(trait_def) == lang_items.fused_iterator_trait() { + self.fused_iterator_conditions(obligation) } else { bug!("unexpected builtin trait {:?}", trait_def) }; diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 53aadfb8a44d8..adbc7d12a648d 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -2259,6 +2259,20 @@ impl<'tcx> SelectionContext<'_, 'tcx> { } } + fn fused_iterator_conditions( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + ) -> BuiltinImplConditions<'tcx> { + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); + if let ty::Coroutine(did, ..) = *self_ty.kind() + && self.tcx().coroutine_is_gen(did) + { + BuiltinImplConditions::Where(ty::Binder::dummy(Vec::new())) + } else { + BuiltinImplConditions::None + } + } + /// For default impls, we need to break apart a type into its /// "constituent types" -- meaning, the types that it contains. /// diff --git a/library/core/src/iter/traits/marker.rs b/library/core/src/iter/traits/marker.rs index 8bdbca120d7f9..ad4d63d83b5be 100644 --- a/library/core/src/iter/traits/marker.rs +++ b/library/core/src/iter/traits/marker.rs @@ -28,6 +28,7 @@ pub unsafe trait TrustedFused {} #[rustc_unsafe_specialization_marker] // FIXME: this should be a #[marker] and have another blanket impl for T: TrustedFused // but that ICEs iter::Fuse specializations. +#[cfg_attr(not(bootstrap), lang = "fused_iterator")] pub trait FusedIterator: Iterator {} #[stable(feature = "fused", since = "1.26.0")] diff --git a/tests/ui/coroutine/gen_block_is_fused_iter.rs b/tests/ui/coroutine/gen_block_is_fused_iter.rs new file mode 100644 index 0000000000000..f3e19a7f54f03 --- /dev/null +++ b/tests/ui/coroutine/gen_block_is_fused_iter.rs @@ -0,0 +1,21 @@ +//@ revisions: next old +//@compile-flags: --edition 2024 -Zunstable-options +//@[next] compile-flags: -Znext-solver +//@ check-pass +#![feature(gen_blocks)] + +use std::iter::FusedIterator; + +fn foo() -> impl FusedIterator { + gen { yield 42 } +} + +fn bar() -> impl FusedIterator { + gen { yield 42 } +} + +fn baz() -> impl FusedIterator + Iterator { + gen { yield 42 } +} + +fn main() {}