diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs index 62e16d445c63f..47bd24f1e141c 100644 --- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs @@ -1469,27 +1469,31 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { let hir = tcx.hir(); let Some(body_id) = tcx.hir_node(self.mir_hir_id()).body_id() else { return }; struct FindUselessClone<'hir> { + tcx: TyCtxt<'hir>, + def_id: DefId, pub clones: Vec<&'hir hir::Expr<'hir>>, } impl<'hir> FindUselessClone<'hir> { - pub fn new() -> Self { - Self { clones: vec![] } + pub fn new(tcx: TyCtxt<'hir>, def_id: DefId) -> Self { + Self { tcx, def_id, clones: vec![] } } } impl<'v> Visitor<'v> for FindUselessClone<'v> { fn visit_expr(&mut self, ex: &'v hir::Expr<'v>) { - // FIXME: use `lookup_method_for_diagnostic`? if let hir::ExprKind::MethodCall(segment, _rcvr, args, _span) = ex.kind && segment.ident.name == sym::clone && args.len() == 0 + && let Some(def_id) = self.def_id.as_local() + && let Some(method) = self.tcx.lookup_method_for_diagnostic((def_id, ex.hir_id)) + && Some(self.tcx.parent(method)) == self.tcx.lang_items().clone_trait() { self.clones.push(ex); } hir::intravisit::walk_expr(self, ex); } } - let mut expr_finder = FindUselessClone::new(); + let mut expr_finder = FindUselessClone::new(tcx, self.mir_def_id().into()); let body = hir.body(body_id).value; expr_finder.visit_expr(body); diff --git a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs index c92fccc959fd5..304d41d694175 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs @@ -26,6 +26,9 @@ use rustc_middle::ty::{self, RegionVid, Ty}; use rustc_middle::ty::{Region, TyCtxt}; use rustc_span::symbol::{kw, Ident}; use rustc_span::Span; +use rustc_trait_selection::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; +use rustc_trait_selection::infer::InferCtxtExt; +use rustc_trait_selection::traits::{Obligation, ObligationCtxt}; use crate::borrowck_errors; use crate::session_diagnostics::{ @@ -810,6 +813,7 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> { self.add_static_impl_trait_suggestion(&mut diag, *fr, fr_name, *outlived_fr); self.suggest_adding_lifetime_params(&mut diag, *fr, *outlived_fr); self.suggest_move_on_borrowing_closure(&mut diag); + self.suggest_deref_closure_value(&mut diag); diag } @@ -1039,6 +1043,147 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> { suggest_adding_lifetime_params(self.infcx.tcx, sub, ty_sup, ty_sub, diag); } + #[allow(rustc::diagnostic_outside_of_impl)] + #[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable + /// When encountering a lifetime error caused by the return type of a closure, check the + /// corresponding trait bound and see if dereferencing the closure return value would satisfy + /// them. If so, we produce a structured suggestion. + fn suggest_deref_closure_value(&self, diag: &mut Diag<'_>) { + let tcx = self.infcx.tcx; + let map = tcx.hir(); + + // Get the closure return value and type. + let body_id = map.body_owned_by(self.mir_def_id()); + let body = &map.body(body_id); + let value = &body.value.peel_blocks(); + let hir::Node::Expr(closure_expr) = tcx.hir_node_by_def_id(self.mir_def_id()) else { + return; + }; + let fn_call_id = tcx.parent_hir_id(self.mir_hir_id()); + let hir::Node::Expr(expr) = tcx.hir_node(fn_call_id) else { return }; + let def_id = map.enclosing_body_owner(fn_call_id); + let tables = tcx.typeck(def_id); + let Some(return_value_ty) = tables.node_type_opt(value.hir_id) else { return }; + let return_value_ty = self.infcx.resolve_vars_if_possible(return_value_ty); + + // We don't use `ty.peel_refs()` to get the number of `*`s needed to get the root type. + let mut ty = return_value_ty; + let mut count = 0; + while let ty::Ref(_, t, _) = ty.kind() { + ty = *t; + count += 1; + } + if !self.infcx.type_is_copy_modulo_regions(self.param_env, ty) { + return; + } + + // Build a new closure where the return type is an owned value, instead of a ref. + let Some(ty::Closure(did, args)) = + tables.node_type_opt(closure_expr.hir_id).as_ref().map(|ty| ty.kind()) + else { + return; + }; + let sig = args.as_closure().sig(); + let closure_sig_as_fn_ptr_ty = Ty::new_fn_ptr( + tcx, + sig.map_bound(|s| { + let unsafety = hir::Unsafety::Normal; + use rustc_target::spec::abi; + tcx.mk_fn_sig( + [s.inputs()[0]], + s.output().peel_refs(), + s.c_variadic, + unsafety, + abi::Abi::Rust, + ) + }), + ); + let parent_args = GenericArgs::identity_for_item( + tcx, + tcx.typeck_root_def_id(self.mir_def_id().to_def_id()), + ); + let closure_kind = args.as_closure().kind(); + let closure_kind_ty = Ty::from_closure_kind(tcx, closure_kind); + let tupled_upvars_ty = self.infcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: closure_expr.span, + }); + let closure_args = ty::ClosureArgs::new( + tcx, + ty::ClosureArgsParts { + parent_args, + closure_kind_ty, + closure_sig_as_fn_ptr_ty, + tupled_upvars_ty, + }, + ); + let closure_ty = Ty::new_closure(tcx, *did, closure_args.args); + let closure_ty = tcx.erase_regions(closure_ty); + + let hir::ExprKind::MethodCall(_, rcvr, args, _) = expr.kind else { return }; + let Some(pos) = args + .iter() + .enumerate() + .find(|(_, arg)| arg.hir_id == closure_expr.hir_id) + .map(|(i, _)| i) + else { + return; + }; + // The found `Self` type of the method call. + let Some(possible_rcvr_ty) = tables.node_type_opt(rcvr.hir_id) else { return }; + + // The `MethodCall` expression is `Res::Err`, so we search for the method on the `rcvr_ty`. + let Some(method) = tcx.lookup_method_for_diagnostic((self.mir_def_id(), expr.hir_id)) + else { + return; + }; + + // Get the type for the parameter corresponding to the argument the closure with the + // lifetime error we had. + let Some(input) = tcx + .fn_sig(method) + .instantiate_identity() + .inputs() + .skip_binder() + // Methods have a `self` arg, so `pos` is actually `+ 1` to match the method call arg. + .get(pos + 1) + else { + return; + }; + + trace!(?input); + + let ty::Param(closure_param) = input.kind() else { return }; + + // Get the arguments for the found method, only specifying that `Self` is the receiver type. + let args = GenericArgs::for_item(tcx, method, |param, _| { + if param.index == 0 { + possible_rcvr_ty.into() + } else if param.index == closure_param.index { + closure_ty.into() + } else { + self.infcx.var_for_def(expr.span, param) + } + }); + + let preds = tcx.predicates_of(method).instantiate(tcx, args); + + let ocx = ObligationCtxt::new(&self.infcx); + ocx.register_obligations(preds.iter().map(|(pred, span)| { + trace!(?pred); + Obligation::misc(tcx, span, self.mir_def_id(), self.param_env, pred) + })); + + if ocx.select_all_or_error().is_empty() { + diag.span_suggestion_verbose( + value.span.shrink_to_lo(), + "dereference the return value", + "*".repeat(count), + Applicability::MachineApplicable, + ); + } + } + #[allow(rustc::diagnostic_outside_of_impl)] #[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable fn suggest_move_on_borrowing_closure(&self, diag: &mut Diag<'_>) { diff --git a/compiler/rustc_hir_typeck/src/lib.rs b/compiler/rustc_hir_typeck/src/lib.rs index 700dde184f2d9..476df9ae793f5 100644 --- a/compiler/rustc_hir_typeck/src/lib.rs +++ b/compiler/rustc_hir_typeck/src/lib.rs @@ -56,7 +56,7 @@ use rustc_data_structures::unord::UnordSet; use rustc_errors::{codes::*, struct_span_code_err, ErrorGuaranteed}; use rustc_hir as hir; use rustc_hir::def::{DefKind, Res}; -use rustc_hir::intravisit::Visitor; +use rustc_hir::intravisit::{Map, Visitor}; use rustc_hir::{HirIdMap, Node}; use rustc_hir_analysis::check::check_abi; use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer; @@ -436,6 +436,28 @@ fn fatally_break_rust(tcx: TyCtxt<'_>, span: Span) -> ! { diag.emit() } +pub fn lookup_method_for_diagnostic<'tcx>( + tcx: TyCtxt<'tcx>, + (def_id, hir_id): (LocalDefId, hir::HirId), +) -> Option { + let root_ctxt = TypeckRootCtxt::new(tcx, def_id); + let param_env = tcx.param_env(def_id); + let fn_ctxt = FnCtxt::new(&root_ctxt, param_env, def_id); + let hir::Node::Expr(expr) = tcx.hir().hir_node(hir_id) else { + return None; + }; + let hir::ExprKind::MethodCall(segment, rcvr, _, _) = expr.kind else { + return None; + }; + let tables = tcx.typeck(def_id); + // The found `Self` type of the method call. + let possible_rcvr_ty = tables.node_type_opt(rcvr.hir_id)?; + fn_ctxt + .lookup_method_for_diagnostic(possible_rcvr_ty, segment, expr.span, expr, rcvr) + .ok() + .map(|method| method.def_id) +} + pub fn provide(providers: &mut Providers) { method::provide(providers); *providers = Providers { @@ -443,6 +465,7 @@ pub fn provide(providers: &mut Providers) { diagnostic_only_typeck, has_typeck_results, used_trait_imports, + lookup_method_for_diagnostic: lookup_method_for_diagnostic, ..*providers }; } diff --git a/compiler/rustc_middle/src/query/keys.rs b/compiler/rustc_middle/src/query/keys.rs index c1548eb99f52f..faa137019cb92 100644 --- a/compiler/rustc_middle/src/query/keys.rs +++ b/compiler/rustc_middle/src/query/keys.rs @@ -555,6 +555,19 @@ impl Key for HirId { } } +impl Key for (LocalDefId, HirId) { + type Cache = DefaultCache; + + fn default_span(&self, tcx: TyCtxt<'_>) -> Span { + tcx.hir().span(self.1) + } + + #[inline(always)] + fn key_as_def_id(&self) -> Option { + Some(self.0.into()) + } +} + impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) { type Cache = DefaultCache; diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 5ef7a20f460ed..394515f091f27 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -983,6 +983,9 @@ rustc_queries! { query diagnostic_only_typeck(key: LocalDefId) -> &'tcx ty::TypeckResults<'tcx> { desc { |tcx| "type-checking `{}`", tcx.def_path_str(key) } } + query lookup_method_for_diagnostic((def_id, hir_id): (LocalDefId, hir::HirId)) -> Option { + desc { |tcx| "lookup_method_for_diagnostics `{}`", tcx.def_path_str(def_id) } + } query used_trait_imports(key: LocalDefId) -> &'tcx UnordSet { desc { |tcx| "finding used_trait_imports `{}`", tcx.def_path_str(key) } diff --git a/tests/ui/closures/return-value-lifetime-error.fixed b/tests/ui/closures/return-value-lifetime-error.fixed new file mode 100644 index 0000000000000..bf1f7e4a6cfd3 --- /dev/null +++ b/tests/ui/closures/return-value-lifetime-error.fixed @@ -0,0 +1,16 @@ +//@ run-rustfix +use std::collections::HashMap; + +fn main() { + let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3]; + + let mut counts = HashMap::new(); + for num in vs { + let count = counts.entry(num).or_insert(0); + *count += 1; + } + + let _ = counts.iter().max_by_key(|(_, v)| **v); + //~^ ERROR lifetime may not live long enough + //~| HELP dereference the return value +} diff --git a/tests/ui/closures/return-value-lifetime-error.rs b/tests/ui/closures/return-value-lifetime-error.rs new file mode 100644 index 0000000000000..411c91f413ecc --- /dev/null +++ b/tests/ui/closures/return-value-lifetime-error.rs @@ -0,0 +1,16 @@ +//@ run-rustfix +use std::collections::HashMap; + +fn main() { + let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3]; + + let mut counts = HashMap::new(); + for num in vs { + let count = counts.entry(num).or_insert(0); + *count += 1; + } + + let _ = counts.iter().max_by_key(|(_, v)| v); + //~^ ERROR lifetime may not live long enough + //~| HELP dereference the return value +} diff --git a/tests/ui/closures/return-value-lifetime-error.stderr b/tests/ui/closures/return-value-lifetime-error.stderr new file mode 100644 index 0000000000000..a0ad127db2891 --- /dev/null +++ b/tests/ui/closures/return-value-lifetime-error.stderr @@ -0,0 +1,16 @@ +error: lifetime may not live long enough + --> $DIR/return-value-lifetime-error.rs:13:47 + | +LL | let _ = counts.iter().max_by_key(|(_, v)| v); + | ------- ^ returning this value requires that `'1` must outlive `'2` + | | | + | | return type of closure is &'2 &i32 + | has type `&'1 (&i32, &i32)` + | +help: dereference the return value + | +LL | let _ = counts.iter().max_by_key(|(_, v)| **v); + | ++ + +error: aborting due to 1 previous error +