diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index f65dfea04eb00..54adbb35e8708 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -920,6 +920,10 @@ rustc_queries! { desc { |tcx| "looking up const stability of `{}`", tcx.def_path_str(def_id) } } + query should_inherit_track_caller(def_id: DefId) -> bool { + desc { |tcx| "computing should_inherit_track_caller of `{}`", tcx.def_path_str(def_id) } + } + query lookup_deprecation_entry(def_id: DefId) -> Option { desc { |tcx| "checking whether `{}` is deprecated", tcx.def_path_str(def_id) } } diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index 41d953216e0dd..261a19f862e02 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -227,8 +227,9 @@ impl<'tcx> InstanceDef<'tcx> { pub fn requires_caller_location(&self, tcx: TyCtxt<'_>) -> bool { match *self { - InstanceDef::Item(def) => { - tcx.codegen_fn_attrs(def.did).flags.contains(CodegenFnAttrFlags::TRACK_CALLER) + InstanceDef::Item(ty::WithOptConstParam { did: def_id, .. }) + | InstanceDef::Virtual(def_id, _) => { + tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::TRACK_CALLER) } _ => false, } @@ -403,7 +404,7 @@ impl<'tcx> Instance<'tcx> { def_id: DefId, substs: SubstsRef<'tcx>, ) -> Option> { - debug!("resolve(def_id={:?}, substs={:?})", def_id, substs); + debug!("resolve_for_vtable(def_id={:?}, substs={:?})", def_id, substs); let fn_sig = tcx.fn_sig(def_id); let is_vtable_shim = !fn_sig.inputs().skip_binder().is_empty() && fn_sig.input(0).skip_binder().is_param(0) @@ -412,7 +413,50 @@ impl<'tcx> Instance<'tcx> { debug!(" => associated item with unsizeable self: Self"); Some(Instance { def: InstanceDef::VtableShim(def_id), substs }) } else { - Instance::resolve_for_fn_ptr(tcx, param_env, def_id, substs) + Instance::resolve(tcx, param_env, def_id, substs).ok().flatten().map(|mut resolved| { + match resolved.def { + InstanceDef::Item(def) => { + // We need to generate a shim when we cannot guarantee that + // the caller of a trait object method will be aware of + // `#[track_caller]` - this ensures that the caller + // and callee ABI will always match. + // + // The shim is generated when all of these conditions are met: + // + // 1) The underlying method expects a caller location parameter + // in the ABI + if resolved.def.requires_caller_location(tcx) + // 2) The caller location parameter comes from having `#[track_caller]` + // on the implementation, and *not* on the trait method. + && !tcx.should_inherit_track_caller(def.did) + // If the method implementation comes from the trait definition itself + // (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`), + // then we don't need to generate a shim. This check is needed because + // `should_inherit_track_caller` returns `false` if our method + // implementation comes from the trait block, and not an impl block + && !matches!( + tcx.opt_associated_item(def.did), + Some(ty::AssocItem { + container: ty::AssocItemContainer::TraitContainer(_), + .. + }) + ) + { + debug!( + " => vtable fn pointer created for function with #[track_caller]" + ); + resolved.def = InstanceDef::ReifyShim(def.did); + } + } + InstanceDef::Virtual(def_id, _) => { + debug!(" => vtable fn pointer created for virtual call"); + resolved.def = InstanceDef::ReifyShim(def_id); + } + _ => {} + } + + resolved + }) } } diff --git a/compiler/rustc_typeck/src/collect.rs b/compiler/rustc_typeck/src/collect.rs index 5d83375e5a1b8..6edb1f145b48e 100644 --- a/compiler/rustc_typeck/src/collect.rs +++ b/compiler/rustc_typeck/src/collect.rs @@ -93,6 +93,7 @@ pub fn provide(providers: &mut Providers) { generator_kind, codegen_fn_attrs, collect_mod_item_types, + should_inherit_track_caller, ..*providers }; } @@ -2652,7 +2653,7 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, id: DefId) -> CodegenFnAttrs { let attrs = tcx.get_attrs(id); let mut codegen_fn_attrs = CodegenFnAttrs::new(); - if should_inherit_track_caller(tcx, id) { + if tcx.should_inherit_track_caller(id) { codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER; } diff --git a/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs b/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs index 3b2a2238fa82d..22622e228f547 100644 --- a/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs +++ b/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs @@ -2,22 +2,57 @@ trait Tracked { #[track_caller] - fn handle(&self) { + fn track_caller_trait_method(&self, line: u32, col: u32) { let location = std::panic::Location::caller(); assert_eq!(location.file(), file!()); - // we only call this via trait object, so the def site should *always* be returned - assert_eq!(location.line(), line!() - 4); - assert_eq!(location.column(), 5); + // The trait method definition is annotated with `#[track_caller]`, + // so caller location information will work through a method + // call on a trait object + assert_eq!(location.line(), line, "Bad line"); + assert_eq!(location.column(), col, "Bad col"); } + + fn track_caller_not_on_trait_method(&self); + + #[track_caller] + fn track_caller_through_self(self: Box, line: u32, col: u32); } -impl Tracked for () {} -impl Tracked for u8 {} +impl Tracked for () { + // We have `#[track_caller]` on the implementation of the method, + // but not on the definition of the method in the trait. Therefore, + // caller location information will *not* work through a method call + // on a trait object. Instead, we will get the location of this method + #[track_caller] + fn track_caller_not_on_trait_method(&self) { + let location = std::panic::Location::caller(); + assert_eq!(location.file(), file!()); + assert_eq!(location.line(), line!() - 3); + assert_eq!(location.column(), 5); + } + + // We don't have a `#[track_caller]` attribute, but + // `#[track_caller]` is present on the trait definition, + // so we'll still get location information + fn track_caller_through_self(self: Box, line: u32, col: u32) { + let location = std::panic::Location::caller(); + assert_eq!(location.file(), file!()); + // The trait method definition is annotated with `#[track_caller]`, + // so caller location information will work through a method + // call on a trait object + assert_eq!(location.line(), line, "Bad line"); + assert_eq!(location.column(), col, "Bad col"); + } +} fn main() { - let tracked: &dyn Tracked = &5u8; - tracked.handle(); + let tracked: &dyn Tracked = &(); + tracked.track_caller_trait_method(line!(), 13); // The column is the start of 'track_caller_trait_method' const TRACKED: &dyn Tracked = &(); - TRACKED.handle(); + TRACKED.track_caller_trait_method(line!(), 13); // The column is the start of 'track_caller_trait_method' + TRACKED.track_caller_not_on_trait_method(); + + let boxed: Box = Box::new(()); + boxed.track_caller_through_self(line!(), 11); // The column is the start of `track_caller_through_self` }