Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide suggestion to dereference closure tail if appropriate #122213

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1469,27 +1469,31 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
let hir = tcx.hir();
let Some(body_id) = tcx.hir_node(self.mir_hir_id()).body_id() else { return };
struct FindUselessClone<'hir> {
tcx: TyCtxt<'hir>,
def_id: DefId,
pub clones: Vec<&'hir hir::Expr<'hir>>,
}
impl<'hir> FindUselessClone<'hir> {
pub fn new() -> Self {
Self { clones: vec![] }
pub fn new(tcx: TyCtxt<'hir>, def_id: DefId) -> Self {
Self { tcx, def_id, clones: vec![] }
}
}

impl<'v> Visitor<'v> for FindUselessClone<'v> {
fn visit_expr(&mut self, ex: &'v hir::Expr<'v>) {
// FIXME: use `lookup_method_for_diagnostic`?
if let hir::ExprKind::MethodCall(segment, _rcvr, args, _span) = ex.kind
&& segment.ident.name == sym::clone
&& args.len() == 0
&& let Some(def_id) = self.def_id.as_local()
&& let Some(method) = self.tcx.lookup_method_for_diagnostic((def_id, ex.hir_id))
&& Some(self.tcx.parent(method)) == self.tcx.lang_items().clone_trait()
{
self.clones.push(ex);
}
hir::intravisit::walk_expr(self, ex);
}
}
let mut expr_finder = FindUselessClone::new();
let mut expr_finder = FindUselessClone::new(tcx, self.mir_def_id().into());

let body = hir.body(body_id).value;
expr_finder.visit_expr(body);
Expand Down
145 changes: 145 additions & 0 deletions compiler/rustc_borrowck/src/diagnostics/region_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ use rustc_middle::ty::{self, RegionVid, Ty};
use rustc_middle::ty::{Region, TyCtxt};
use rustc_span::symbol::{kw, Ident};
use rustc_span::Span;
use rustc_trait_selection::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_trait_selection::infer::InferCtxtExt;
use rustc_trait_selection::traits::{Obligation, ObligationCtxt};

use crate::borrowck_errors;
use crate::session_diagnostics::{
Expand Down Expand Up @@ -810,6 +813,7 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
self.add_static_impl_trait_suggestion(&mut diag, *fr, fr_name, *outlived_fr);
self.suggest_adding_lifetime_params(&mut diag, *fr, *outlived_fr);
self.suggest_move_on_borrowing_closure(&mut diag);
self.suggest_deref_closure_value(&mut diag);

diag
}
Expand Down Expand Up @@ -1039,6 +1043,147 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
suggest_adding_lifetime_params(self.infcx.tcx, sub, ty_sup, ty_sub, diag);
}

#[allow(rustc::diagnostic_outside_of_impl)]
#[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable
/// When encountering a lifetime error caused by the return type of a closure, check the
/// corresponding trait bound and see if dereferencing the closure return value would satisfy
/// them. If so, we produce a structured suggestion.
fn suggest_deref_closure_value(&self, diag: &mut Diag<'_>) {
let tcx = self.infcx.tcx;
let map = tcx.hir();

// Get the closure return value and type.
let body_id = map.body_owned_by(self.mir_def_id());
let body = &map.body(body_id);
let value = &body.value.peel_blocks();
let hir::Node::Expr(closure_expr) = tcx.hir_node_by_def_id(self.mir_def_id()) else {
return;
};
let fn_call_id = tcx.parent_hir_id(self.mir_hir_id());
let hir::Node::Expr(expr) = tcx.hir_node(fn_call_id) else { return };
let def_id = map.enclosing_body_owner(fn_call_id);
let tables = tcx.typeck(def_id);
let Some(return_value_ty) = tables.node_type_opt(value.hir_id) else { return };
let return_value_ty = self.infcx.resolve_vars_if_possible(return_value_ty);

// We don't use `ty.peel_refs()` to get the number of `*`s needed to get the root type.
let mut ty = return_value_ty;
let mut count = 0;
while let ty::Ref(_, t, _) = ty.kind() {
ty = *t;
count += 1;
}
if !self.infcx.type_is_copy_modulo_regions(self.param_env, ty) {
return;
}

// Build a new closure where the return type is an owned value, instead of a ref.
let Some(ty::Closure(did, args)) =
tables.node_type_opt(closure_expr.hir_id).as_ref().map(|ty| ty.kind())
else {
return;
};
let sig = args.as_closure().sig();
let closure_sig_as_fn_ptr_ty = Ty::new_fn_ptr(
tcx,
sig.map_bound(|s| {
let unsafety = hir::Unsafety::Normal;
use rustc_target::spec::abi;
tcx.mk_fn_sig(
[s.inputs()[0]],
s.output().peel_refs(),
s.c_variadic,
unsafety,
abi::Abi::Rust,
)
}),
);
let parent_args = GenericArgs::identity_for_item(
tcx,
tcx.typeck_root_def_id(self.mir_def_id().to_def_id()),
);
let closure_kind = args.as_closure().kind();
let closure_kind_ty = Ty::from_closure_kind(tcx, closure_kind);
let tupled_upvars_ty = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: closure_expr.span,
});
let closure_args = ty::ClosureArgs::new(
tcx,
ty::ClosureArgsParts {
parent_args,
closure_kind_ty,
closure_sig_as_fn_ptr_ty,
tupled_upvars_ty,
},
);
let closure_ty = Ty::new_closure(tcx, *did, closure_args.args);
let closure_ty = tcx.erase_regions(closure_ty);

let hir::ExprKind::MethodCall(_, rcvr, args, _) = expr.kind else { return };
let Some(pos) = args
.iter()
.enumerate()
.find(|(_, arg)| arg.hir_id == closure_expr.hir_id)
.map(|(i, _)| i)
else {
return;
};
// The found `Self` type of the method call.
let Some(possible_rcvr_ty) = tables.node_type_opt(rcvr.hir_id) else { return };

// The `MethodCall` expression is `Res::Err`, so we search for the method on the `rcvr_ty`.
let Some(method) = tcx.lookup_method_for_diagnostic((self.mir_def_id(), expr.hir_id))
else {
return;
};

// Get the type for the parameter corresponding to the argument the closure with the
// lifetime error we had.
let Some(input) = tcx
.fn_sig(method)
.instantiate_identity()
.inputs()
.skip_binder()
// Methods have a `self` arg, so `pos` is actually `+ 1` to match the method call arg.
.get(pos + 1)
else {
return;
};

trace!(?input);

let ty::Param(closure_param) = input.kind() else { return };

// Get the arguments for the found method, only specifying that `Self` is the receiver type.
let args = GenericArgs::for_item(tcx, method, |param, _| {
if param.index == 0 {
possible_rcvr_ty.into()
} else if param.index == closure_param.index {
closure_ty.into()
} else {
self.infcx.var_for_def(expr.span, param)
}
});

let preds = tcx.predicates_of(method).instantiate(tcx, args);

let ocx = ObligationCtxt::new(&self.infcx);
ocx.register_obligations(preds.iter().map(|(pred, span)| {
trace!(?pred);
Obligation::misc(tcx, span, self.mir_def_id(), self.param_env, pred)
}));

if ocx.select_all_or_error().is_empty() {
diag.span_suggestion_verbose(
value.span.shrink_to_lo(),
"dereference the return value",
"*".repeat(count),
Applicability::MachineApplicable,
);
}
}

#[allow(rustc::diagnostic_outside_of_impl)]
#[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable
fn suggest_move_on_borrowing_closure(&self, diag: &mut Diag<'_>) {
Expand Down
25 changes: 24 additions & 1 deletion compiler/rustc_hir_typeck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use rustc_data_structures::unord::UnordSet;
use rustc_errors::{codes::*, struct_span_code_err, ErrorGuaranteed};
use rustc_hir as hir;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::intravisit::Visitor;
use rustc_hir::intravisit::{Map, Visitor};
use rustc_hir::{HirIdMap, Node};
use rustc_hir_analysis::check::check_abi;
use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
Expand Down Expand Up @@ -435,13 +435,36 @@ fn fatally_break_rust(tcx: TyCtxt<'_>, span: Span) -> ! {
diag.emit()
}

pub fn lookup_method_for_diagnostic<'tcx>(
tcx: TyCtxt<'tcx>,
(def_id, hir_id): (LocalDefId, hir::HirId),
) -> Option<DefId> {
let root_ctxt = TypeckRootCtxt::new(tcx, def_id);
let param_env = tcx.param_env(def_id);
let fn_ctxt = FnCtxt::new(&root_ctxt, param_env, def_id);
let hir::Node::Expr(expr) = tcx.hir().hir_node(hir_id) else {
return None;
};
let hir::ExprKind::MethodCall(segment, rcvr, _, _) = expr.kind else {
return None;
};
let tables = tcx.typeck(def_id);
// The found `Self` type of the method call.
let possible_rcvr_ty = tables.node_type_opt(rcvr.hir_id)?;
fn_ctxt
.lookup_method_for_diagnostic(possible_rcvr_ty, segment, expr.span, expr, rcvr)
.ok()
.map(|method| method.def_id)
}

pub fn provide(providers: &mut Providers) {
method::provide(providers);
*providers = Providers {
typeck,
diagnostic_only_typeck,
has_typeck_results,
used_trait_imports,
lookup_method_for_diagnostic: lookup_method_for_diagnostic,
..*providers
};
}
13 changes: 13 additions & 0 deletions compiler/rustc_middle/src/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,19 @@ impl Key for HirId {
}
}

impl Key for (LocalDefId, HirId) {
type Cache<V> = DefaultCache<Self, V>;

fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
tcx.hir().span(self.1)
}

#[inline(always)]
fn key_as_def_id(&self) -> Option<DefId> {
Some(self.0.into())
}
}

impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
type Cache<V> = DefaultCache<Self, V>;

Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,9 @@ rustc_queries! {
query diagnostic_only_typeck(key: LocalDefId) -> &'tcx ty::TypeckResults<'tcx> {
desc { |tcx| "type-checking `{}`", tcx.def_path_str(key) }
}
query lookup_method_for_diagnostic((def_id, hir_id): (LocalDefId, hir::HirId)) -> Option<DefId> {
desc { |tcx| "lookup_method_for_diagnostics `{}`", tcx.def_path_str(def_id) }
}

query used_trait_imports(key: LocalDefId) -> &'tcx UnordSet<LocalDefId> {
desc { |tcx| "finding used_trait_imports `{}`", tcx.def_path_str(key) }
Expand Down
16 changes: 16 additions & 0 deletions tests/ui/closures/return-value-lifetime-error.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@ run-rustfix
use std::collections::HashMap;

fn main() {
let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];

let mut counts = HashMap::new();
for num in vs {
let count = counts.entry(num).or_insert(0);
*count += 1;
}

let _ = counts.iter().max_by_key(|(_, v)| **v);
//~^ ERROR lifetime may not live long enough
//~| HELP dereference the return value
}
16 changes: 16 additions & 0 deletions tests/ui/closures/return-value-lifetime-error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@ run-rustfix
use std::collections::HashMap;

fn main() {
let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];

let mut counts = HashMap::new();
for num in vs {
let count = counts.entry(num).or_insert(0);
*count += 1;
}

let _ = counts.iter().max_by_key(|(_, v)| v);
//~^ ERROR lifetime may not live long enough
//~| HELP dereference the return value
}
16 changes: 16 additions & 0 deletions tests/ui/closures/return-value-lifetime-error.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error: lifetime may not live long enough
--> $DIR/return-value-lifetime-error.rs:13:47
|
LL | let _ = counts.iter().max_by_key(|(_, v)| v);
| ------- ^ returning this value requires that `'1` must outlive `'2`
| | |
| | return type of closure is &'2 &i32
| has type `&'1 (&i32, &i32)`
|
help: dereference the return value
|
LL | let _ = counts.iter().max_by_key(|(_, v)| **v);
| ++

error: aborting due to 1 previous error

Loading