Skip to content

Commit

Permalink
Make gen blocks implement the Iterator trait
Browse files Browse the repository at this point in the history
  • Loading branch information
oli-obk committed Oct 23, 2023
1 parent f5fb745 commit ba499c7
Show file tree
Hide file tree
Showing 17 changed files with 280 additions and 6 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
},
)
}
Some(hir::CoroutineKind::Gen(hir::AsyncCoroutineKind::Fn)) => todo!(),

_ => astconv.ty_infer(None, decl.output.span()),
},
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ pub enum SelectionCandidate<'tcx> {
/// generated for an async construct.
FutureCandidate,

/// Implementation of an `Iterator` trait by one of the generator types
/// generated for a gen construct.
IteratorCandidate,

/// Implementation of a `Fn`-family trait by one of the anonymous
/// types generated for a fn pointer type (e.g., `fn(int) -> int`)
FnPointerCandidate {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,11 @@ impl<'tcx> TyCtxt<'tcx> {
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Async(_)))
}

/// Returns `true` if the node pointed to by `def_id` is a coroutine for a gen construct.
pub fn coroutine_is_gen(self, def_id: DefId) -> bool {
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Gen(_)))
}

pub fn stability(self) -> &'tcx stability::Index {
self.stability_index(())
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ symbols! {
IpAddr,
IrTyKind,
Is,
Item,
ItemContext,
IterEmpty,
IterOnce,
Expand Down Expand Up @@ -911,6 +912,7 @@ symbols! {
iter,
iter_mut,
iter_repeat,
iterator,
iterator_collect_fn,
kcfi,
keyword,
Expand Down
12 changes: 11 additions & 1 deletion compiler/rustc_trait_selection/src/solve/assembly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,15 @@ pub(super) trait GoalKind<'tcx>:
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx>;

/// A coroutine (that doesn't come from an `async` desugaring) is known to
/// A coroutine (that comes from an `gen` desugaring) is known to implement
/// `Iterator<Item = O>`, where `O` is given by the generator's yield type
/// that was computed during type-checking.
fn consider_builtin_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx>;

/// A coroutine (that doesn't come from an `async` or `gen` desugaring) is known to
/// implement `Coroutine<R, Yield = Y, Return = O>`, given the resume, yield,
/// and return types of the coroutine computed during type-checking.
fn consider_builtin_coroutine_candidate(
Expand Down Expand Up @@ -552,6 +560,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
G::consider_builtin_pointee_candidate(self, goal)
} else if lang_items.future_trait() == Some(trait_def_id) {
G::consider_builtin_future_candidate(self, goal)
} else if lang_items.iterator_trait() == Some(trait_def_id) {
G::consider_builtin_iterator_candidate(self, goal)
} else if lang_items.gen_trait() == Some(trait_def_id) {
G::consider_builtin_coroutine_candidate(self, goal)
} else if lang_items.discriminant_kind_trait() == Some(trait_def_id) {
Expand Down
35 changes: 33 additions & 2 deletions compiler/rustc_trait_selection/src/solve/project_goals/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,37 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
)
}

fn consider_builtin_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx> {
let self_ty = goal.predicate.self_ty();
let ty::Coroutine(def_id, args, _) = *self_ty.kind() else {
return Err(NoSolution);
};

// Generators are not Iterators unless they come from `gen` desugaring
let tcx = ecx.tcx();
if !tcx.coroutine_is_gen(def_id) {
return Err(NoSolution);
}

let term = args.as_coroutine().yield_ty().into();

Self::consider_implied_clause(
ecx,
goal,
ty::ProjectionPredicate {
projection_ty: ty::AliasTy::new(ecx.tcx(), goal.predicate.def_id(), [self_ty]),
term,
}
.to_predicate(tcx),
// Technically, we need to check that the iterator type is Sized,
// but that's already proven by the generator being WF.
[],
)
}

fn consider_builtin_coroutine_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
Expand All @@ -496,7 +527,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {

// `async`-desugared coroutines do not implement the coroutine trait
let tcx = ecx.tcx();
if tcx.coroutine_is_async(def_id) {
if tcx.coroutine_is_async(def_id) || tcx.coroutine_is_gen(def_id) {
return Err(NoSolution);
}

Expand All @@ -523,7 +554,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
term,
}
.to_predicate(tcx),
// Technically, we need to check that the future type is Sized,
// Technically, we need to check that the coroutine type is Sized,
// but that's already proven by the coroutine being WF.
[],
)
Expand Down
26 changes: 25 additions & 1 deletion compiler/rustc_trait_selection/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,30 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
}

fn consider_builtin_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx> {
if goal.predicate.polarity != ty::ImplPolarity::Positive {
return Err(NoSolution);
}

let ty::Coroutine(def_id, _, _) = *goal.predicate.self_ty().kind() else {
return Err(NoSolution);
};

// Coroutines are not iterators unless they come from `gen` desugaring
let tcx = ecx.tcx();
if !tcx.coroutine_is_gen(def_id) {
return Err(NoSolution);
}

// Gen coroutines unconditionally implement `Iterator`
// Technically, we need to check that the iterator output type is Sized,
// but that's already proven by the coroutines being WF.
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
}

fn consider_builtin_coroutine_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
Expand All @@ -350,7 +374,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {

// `async`-desugared coroutines do not implement the coroutine trait
let tcx = ecx.tcx();
if tcx.coroutine_is_async(def_id) {
if tcx.coroutine_is_async(def_id) || tcx.coroutine_is_gen(def_id) {
return Err(NoSolution);
}

Expand Down
48 changes: 47 additions & 1 deletion compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1789,7 +1789,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());

let lang_items = selcx.tcx().lang_items();
if [lang_items.gen_trait(), lang_items.future_trait()].contains(&Some(trait_ref.def_id))
if [lang_items.gen_trait(), lang_items.future_trait(), lang_items.iterator_trait()].contains(&Some(trait_ref.def_id))
|| selcx.tcx().fn_trait_kind_from_def_id(trait_ref.def_id).is_some()
{
true
Expand Down Expand Up @@ -2006,6 +2006,8 @@ fn confirm_select_candidate<'cx, 'tcx>(
confirm_coroutine_candidate(selcx, obligation, data)
} else if lang_items.future_trait() == Some(trait_def_id) {
confirm_future_candidate(selcx, obligation, data)
} else if lang_items.iterator_trait() == Some(trait_def_id) {
confirm_iterator_candidate(selcx, obligation, data)
} else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
if obligation.predicate.self_ty().is_closure() {
confirm_closure_candidate(selcx, obligation, data)
Expand Down Expand Up @@ -2126,6 +2128,50 @@ fn confirm_future_candidate<'cx, 'tcx>(
.with_addl_obligations(obligations)
}

fn confirm_iterator_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
nested: Vec<PredicateObligation<'tcx>>,
) -> Progress<'tcx> {
let ty::Coroutine(_, args, _) =
selcx.infcx.shallow_resolve(obligation.predicate.self_ty()).kind()
else {
unreachable!()
};
let gen_sig = args.as_coroutine().poly_sig();
let Normalized { value: gen_sig, obligations } = normalize_with_depth(
selcx,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
gen_sig,
);

debug!(?obligation, ?gen_sig, ?obligations, "confirm_future_candidate");

let tcx = selcx.tcx();
let iter_def_id = tcx.require_lang_item(LangItem::Iterator, None);

let predicate = super::util::iterator_trait_ref_and_outputs(
tcx,
iter_def_id,
obligation.predicate.self_ty(),
gen_sig,
)
.map_bound(|(trait_ref, yield_ty)| {
debug_assert_eq!(tcx.associated_item(obligation.predicate.def_id).name, sym::Item);

ty::ProjectionPredicate {
projection_ty: ty::AliasTy::new(tcx, obligation.predicate.def_id, trait_ref.args),
term: yield_ty.into(),
}
});

confirm_param_env_candidate(selcx, obligation, predicate, false)
.with_addl_obligations(nested)
.with_addl_obligations(obligations)
}

fn confirm_builtin_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
self.assemble_coroutine_candidates(obligation, &mut candidates);
} else if lang_items.future_trait() == Some(def_id) {
self.assemble_future_candidates(obligation, &mut candidates);
} else if lang_items.iterator_trait() == Some(def_id) {
self.assemble_iterator_candidates(obligation, &mut candidates);
}

self.assemble_closure_candidates(obligation, &mut candidates);
Expand Down Expand Up @@ -213,7 +215,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
match self_ty.kind() {
// async constructs get lowered to a special kind of coroutine that
// should *not* `impl Coroutine`.
ty::Coroutine(did, ..) if !self.tcx().coroutine_is_async(*did) => {
ty::Coroutine(did, ..)
if !self.tcx().coroutine_is_async(*did) && !self.tcx().coroutine_is_gen(*did) =>
{
debug!(?self_ty, ?obligation, "assemble_coroutine_candidates",);

candidates.vec.push(CoroutineCandidate);
Expand Down Expand Up @@ -243,6 +247,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
}

fn assemble_iterator_candidates(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
candidates: &mut SelectionCandidateSet<'tcx>,
) {
let self_ty = obligation.self_ty().skip_binder();
if let ty::Coroutine(did, ..) = self_ty.kind() {
// gen constructs get lowered to a special kind of coroutine that
// should directly `impl Iterator`.
if self.tcx().coroutine_is_gen(*did) {
debug!(?self_ty, ?obligation, "assemble_iterator_candidates",);

candidates.vec.push(IteratorCandidate);
}
}
}

/// Checks for the artificial impl that the compiler will create for an obligation like `X :
/// FnMut<..>` where `X` is a closure type.
///
Expand Down
35 changes: 35 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_future)
}

IteratorCandidate => {
let vtable_iterator = self.confirm_iterator_candidate(obligation)?;
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
}

FnPointerCandidate { is_const } => {
let data = self.confirm_fn_pointer_candidate(obligation, is_const)?;
ImplSource::Builtin(BuiltinImplSource::Misc, data)
Expand Down Expand Up @@ -780,6 +785,36 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
Ok(nested)
}

fn confirm_iterator_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on generator types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let ty::Coroutine(generator_def_id, args, _) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};

debug!(?obligation, ?generator_def_id, ?args, "confirm_iterator_candidate");

let gen_sig = args.as_coroutine().poly_sig();

let trait_ref = super::util::iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
gen_sig,
)
.map_bound(|(trait_ref, ..)| trait_ref);

let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
}

#[instrument(skip(self), level = "debug")]
fn confirm_closure_candidate(
&mut self,
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| ClosureCandidate { .. }
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
| BuiltinUnsizeCandidate
Expand Down Expand Up @@ -1916,6 +1917,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| ClosureCandidate { .. }
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
| BuiltinUnsizeCandidate
Expand Down Expand Up @@ -1950,6 +1952,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| ClosureCandidate { .. }
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
| BuiltinUnsizeCandidate
Expand All @@ -1964,6 +1967,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| ClosureCandidate { .. }
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
| BuiltinUnsizeCandidate
Expand Down Expand Up @@ -2070,6 +2074,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| ClosureCandidate { .. }
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
| BuiltinUnsizeCandidate
Expand All @@ -2080,6 +2085,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| ClosureCandidate { .. }
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
| BuiltinUnsizeCandidate
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_trait_selection/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,17 @@ pub fn future_trait_ref_and_outputs<'tcx>(
sig.map_bound(|sig| (trait_ref, sig.return_ty))
}

pub fn iterator_trait_ref_and_outputs<'tcx>(
tcx: TyCtxt<'tcx>,
iterator_def_id: DefId,
self_ty: Ty<'tcx>,
sig: ty::PolyGenSig<'tcx>,
) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> {
assert!(!self_ty.has_escaping_bound_vars());
let trait_ref = ty::TraitRef::new(tcx, iterator_def_id, [self_ty]);
sig.map_bound(|sig| (trait_ref, sig.yield_ty))
}

pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool {
assoc_item.defaultness(tcx).is_final()
&& tcx.defaultness(assoc_item.container_id(tcx)).is_final()
Expand Down
Loading

0 comments on commit ba499c7

Please sign in to comment.