From 5043f61e73472b7030898efad88ea92a1bd79b05 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Mon, 10 Jun 2024 17:22:35 +0200 Subject: [PATCH] interpret: dyn trait metadata check: equate traits in a proper way --- .../rustc_const_eval/src/interpret/cast.rs | 8 +- .../rustc_const_eval/src/interpret/memory.rs | 17 +++- .../rustc_const_eval/src/interpret/place.rs | 49 --------- .../src/interpret/terminator.rs | 79 ++++++++------- .../rustc_const_eval/src/interpret/traits.rs | 99 +++++++++++++++---- .../src/interpret/validity.rs | 18 ++-- .../rustc_const_eval/src/interpret/visitor.rs | 4 +- ...iri-3541-dyn-vtable-trait-normalization.rs | 40 ++++++++ 8 files changed, 182 insertions(+), 132 deletions(-) create mode 100644 src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index 19414c72c6a42..fc64835a54753 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -401,13 +401,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let (old_data, old_vptr) = val.to_scalar_pair(); let old_data = old_data.to_pointer(self)?; let old_vptr = old_vptr.to_pointer(self)?; - let (ty, old_trait) = self.get_ptr_vtable(old_vptr)?; - if old_trait != data_a.principal() { - throw_ub!(InvalidVTableTrait { - expected_trait: data_a, - vtable_trait: old_trait, - }); - } + let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?; let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?; self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest) } diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs index 521f28b7123ea..3a7223c4e4afa 100644 --- a/compiler/rustc_const_eval/src/interpret/memory.rs +++ b/compiler/rustc_const_eval/src/interpret/memory.rs @@ -867,19 +867,26 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { .ok_or_else(|| err_ub!(InvalidFunctionPointer(Pointer::new(alloc_id, offset))).into()) } - pub fn get_ptr_vtable( + /// Get the dynamic type of the given vtable pointer. + /// If `expected_trait` is `Some`, it must be a vtable for the given trait. + pub fn get_ptr_vtable_ty( &self, ptr: Pointer>, - ) -> InterpResult<'tcx, (Ty<'tcx>, Option>)> { + expected_trait: Option<&'tcx ty::List>>, + ) -> InterpResult<'tcx, Ty<'tcx>> { trace!("get_ptr_vtable({:?})", ptr); let (alloc_id, offset, _tag) = self.ptr_get_alloc_id(ptr)?; if offset.bytes() != 0 { throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))) } - match self.tcx.try_get_global_alloc(alloc_id) { - Some(GlobalAlloc::VTable(ty, trait_ref)) => Ok((ty, trait_ref)), - _ => throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))), + let Some(GlobalAlloc::VTable(ty, trait_ref)) = self.tcx.try_get_global_alloc(alloc_id) + else { + throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))) + }; + if let Some(expected_trait) = expected_trait { + self.check_vtable_for_type(trait_ref, expected_trait)?; } + Ok(ty) } pub fn alloc_mark_immutable(&mut self, id: AllocId) -> InterpResult<'tcx> { diff --git a/compiler/rustc_const_eval/src/interpret/place.rs b/compiler/rustc_const_eval/src/interpret/place.rs index 4a86ec3f57a4a..a0630d7f3b5d2 100644 --- a/compiler/rustc_const_eval/src/interpret/place.rs +++ b/compiler/rustc_const_eval/src/interpret/place.rs @@ -9,7 +9,6 @@ use tracing::{instrument, trace}; use rustc_ast::Mutability; use rustc_middle::mir; -use rustc_middle::ty; use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; use rustc_middle::ty::Ty; use rustc_middle::{bug, span_bug}; @@ -1017,54 +1016,6 @@ where let layout = self.layout_of(raw.ty)?; Ok(self.ptr_to_mplace(ptr.into(), layout)) } - - /// Turn a place with a `dyn Trait` type into a place with the actual dynamic type. - /// Aso returns the vtable. - pub(super) fn unpack_dyn_trait( - &self, - mplace: &MPlaceTy<'tcx, M::Provenance>, - expected_trait: &'tcx ty::List>, - ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, Pointer>)> { - assert!( - matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)), - "`unpack_dyn_trait` only makes sense on `dyn*` types" - ); - let vtable = mplace.meta().unwrap_meta().to_pointer(self)?; - let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?; - if expected_trait.principal() != vtable_trait { - throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait }); - } - // This is a kind of transmute, from a place with unsized type and metadata to - // a place with sized type and no metadata. - let layout = self.layout_of(ty)?; - let mplace = - MPlaceTy { mplace: MemPlace { meta: MemPlaceMeta::None, ..mplace.mplace }, layout }; - Ok((mplace, vtable)) - } - - /// Turn a `dyn* Trait` type into an value with the actual dynamic type. - /// Also returns the vtable. - pub(super) fn unpack_dyn_star>( - &self, - val: &P, - expected_trait: &'tcx ty::List>, - ) -> InterpResult<'tcx, (P, Pointer>)> { - assert!( - matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)), - "`unpack_dyn_star` only makes sense on `dyn*` types" - ); - let data = self.project_field(val, 0)?; - let vtable = self.project_field(val, 1)?; - let vtable = self.read_pointer(&vtable.to_op(self)?)?; - let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?; - if expected_trait.principal() != vtable_trait { - throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait }); - } - // `data` is already the right thing but has the wrong type. So we transmute it. - let layout = self.layout_of(ty)?; - let data = data.transmute(layout, self)?; - Ok((data, vtable)) - } } // Some nodes are used a lot. Make sure they don't unintentionally get bigger. diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index cbfe25ca8df51..353c52da0aa87 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use either::Either; +use rustc_middle::ty::TyCtxt; use tracing::trace; use rustc_middle::span_bug; @@ -827,49 +828,47 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { }; // Obtain the underlying trait we are working on, and the adjusted receiver argument. - let (vptr, dyn_ty, adjusted_receiver) = if let ty::Dynamic(data, _, ty::DynStar) = - receiver_place.layout.ty.kind() - { - let (recv, vptr) = self.unpack_dyn_star(&receiver_place, data)?; - let (dyn_ty, _dyn_trait) = self.get_ptr_vtable(vptr)?; + let (dyn_trait, dyn_ty, adjusted_receiver) = + if let ty::Dynamic(data, _, ty::DynStar) = receiver_place.layout.ty.kind() { + let recv = self.unpack_dyn_star(&receiver_place, data)?; + + (data.principal(), recv.layout.ty, recv.ptr()) + } else { + // Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`. + // (For that reason we also cannot use `unpack_dyn_trait`.) + let receiver_tail = self.tcx.struct_tail_erasing_lifetimes( + receiver_place.layout.ty, + self.param_env, + ); + let ty::Dynamic(receiver_trait, _, ty::Dyn) = receiver_tail.kind() else { + span_bug!( + self.cur_span(), + "dynamic call on non-`dyn` type {}", + receiver_tail + ) + }; + assert!(receiver_place.layout.is_unsized()); - (vptr, dyn_ty, recv.ptr()) - } else { - // Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`. - // (For that reason we also cannot use `unpack_dyn_trait`.) - let receiver_tail = self - .tcx - .struct_tail_erasing_lifetimes(receiver_place.layout.ty, self.param_env); - let ty::Dynamic(data, _, ty::Dyn) = receiver_tail.kind() else { - span_bug!( - self.cur_span(), - "dynamic call on non-`dyn` type {}", - receiver_tail - ) - }; - assert!(receiver_place.layout.is_unsized()); - - // Get the required information from the vtable. - let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?; - let (dyn_ty, dyn_trait) = self.get_ptr_vtable(vptr)?; - if dyn_trait != data.principal() { - throw_ub!(InvalidVTableTrait { - expected_trait: data, - vtable_trait: dyn_trait, - }); - } + // Get the required information from the vtable. + let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?; + let dyn_ty = self.get_ptr_vtable_ty(vptr, Some(receiver_trait))?; - // It might be surprising that we use a pointer as the receiver even if this - // is a by-val case; this works because by-val passing of an unsized `dyn - // Trait` to a function is actually desugared to a pointer. - (vptr, dyn_ty, receiver_place.ptr()) - }; + // It might be surprising that we use a pointer as the receiver even if this + // is a by-val case; this works because by-val passing of an unsized `dyn + // Trait` to a function is actually desugared to a pointer. + (receiver_trait.principal(), dyn_ty, receiver_place.ptr()) + }; // Now determine the actual method to call. We can do that in two different ways and // compare them to ensure everything fits. - let Some(ty::VtblEntry::Method(fn_inst)) = - self.get_vtable_entries(vptr)?.get(idx).copied() - else { + let vtable_entries = if let Some(dyn_trait) = dyn_trait { + let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty); + let trait_ref = self.tcx.erase_regions(trait_ref); + self.tcx.vtable_entries(trait_ref) + } else { + TyCtxt::COMMON_VTABLE_ENTRIES + }; + let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else { // FIXME(fee1-dead) these could be variants of the UB info enum instead of this throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method); }; @@ -974,11 +973,11 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let place = match place.layout.ty.kind() { ty::Dynamic(data, _, ty::Dyn) => { // Dropping a trait object. Need to find actual drop fn. - self.unpack_dyn_trait(&place, data)?.0 + self.unpack_dyn_trait(&place, data)? } ty::Dynamic(data, _, ty::DynStar) => { // Dropping a `dyn*`. Need to find actual drop fn. - self.unpack_dyn_star(&place, data)?.0 + self.unpack_dyn_star(&place, data)? } _ => { debug_assert_eq!( diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs index 244a6ba48a4d3..eab69b85fb1a4 100644 --- a/compiler/rustc_const_eval/src/interpret/traits.rs +++ b/compiler/rustc_const_eval/src/interpret/traits.rs @@ -1,11 +1,14 @@ +use rustc_infer::infer::TyCtxtInferExt; +use rustc_infer::traits::ObligationCause; use rustc_middle::mir::interpret::{InterpResult, Pointer}; use rustc_middle::ty::layout::LayoutOf; -use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::ty::{self, Ty}; use rustc_target::abi::{Align, Size}; +use rustc_trait_selection::traits::ObligationCtxt; use tracing::trace; use super::util::ensure_monomorphic_enough; -use super::{InterpCx, Machine}; +use super::{throw_ub, InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable}; impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { /// Creates a dynamic vtable for the given type and vtable origin. This is used only for @@ -33,28 +36,88 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { Ok(vtable_ptr.into()) } - /// Returns a high-level representation of the entries of the given vtable. - pub fn get_vtable_entries( - &self, - vtable: Pointer>, - ) -> InterpResult<'tcx, &'tcx [ty::VtblEntry<'tcx>]> { - let (ty, poly_trait_ref) = self.get_ptr_vtable(vtable)?; - Ok(if let Some(poly_trait_ref) = poly_trait_ref { - let trait_ref = poly_trait_ref.with_self_ty(*self.tcx, ty); - let trait_ref = self.tcx.erase_regions(trait_ref); - self.tcx.vtable_entries(trait_ref) - } else { - TyCtxt::COMMON_VTABLE_ENTRIES - }) - } - pub fn get_vtable_size_and_align( &self, vtable: Pointer>, ) -> InterpResult<'tcx, (Size, Align)> { - let (ty, _trait_ref) = self.get_ptr_vtable(vtable)?; + let ty = self.get_ptr_vtable_ty(vtable, None)?; let layout = self.layout_of(ty)?; assert!(layout.is_sized(), "there are no vtables for unsized types"); Ok((layout.size, layout.align.abi)) } + + /// Check that the given vtable trait is valid for a pointer/reference/place with the given + /// expected trait type. + pub(super) fn check_vtable_for_type( + &self, + vtable_trait: Option>, + expected_trait: &'tcx ty::List>, + ) -> InterpResult<'tcx> { + // Fast path: if they are equal, it's all fine. + if expected_trait.principal() == vtable_trait { + return Ok(()); + } + if let (Some(expected_trait), Some(vtable_trait)) = + (expected_trait.principal(), vtable_trait) + { + // Slow path: spin up an inference context to check if these traits are sufficiently equal. + let infcx = self.tcx.infer_ctxt().build(); + let ocx = ObligationCtxt::new(&infcx); + let cause = ObligationCause::dummy_with_span(self.cur_span()); + // equate the two trait refs + if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() { + if ocx.select_all_or_error().is_empty() { + // All good. + return Ok(()); + } + } + eprintln!("*not* equal: {expected_trait:?} {vtable_trait:?}"); + } + throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait }); + } + + /// Turn a place with a `dyn Trait` type into a place with the actual dynamic type. + pub(super) fn unpack_dyn_trait( + &self, + mplace: &MPlaceTy<'tcx, M::Provenance>, + expected_trait: &'tcx ty::List>, + ) -> InterpResult<'tcx, MPlaceTy<'tcx, M::Provenance>> { + assert!( + matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)), + "`unpack_dyn_trait` only makes sense on `dyn*` types" + ); + let vtable = mplace.meta().unwrap_meta().to_pointer(self)?; + let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?; + // This is a kind of transmute, from a place with unsized type and metadata to + // a place with sized type and no metadata. + let layout = self.layout_of(ty)?; + let mplace = mplace.offset_with_meta( + Size::ZERO, + OffsetMode::Wrapping, + MemPlaceMeta::None, + layout, + self, + )?; + Ok(mplace) + } + + /// Turn a `dyn* Trait` type into an value with the actual dynamic type. + pub(super) fn unpack_dyn_star>( + &self, + val: &P, + expected_trait: &'tcx ty::List>, + ) -> InterpResult<'tcx, P> { + assert!( + matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)), + "`unpack_dyn_star` only makes sense on `dyn*` types" + ); + let data = self.project_field(val, 0)?; + let vtable = self.project_field(val, 1)?; + let vtable = self.read_pointer(&vtable.to_op(self)?)?; + let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?; + // `data` is already the right thing but has the wrong type. So we transmute it. + let layout = self.layout_of(ty)?; + let data = data.transmute(layout, self)?; + Ok(data) + } } diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs index 3407c7b8c7925..92462e17da8a3 100644 --- a/compiler/rustc_const_eval/src/interpret/validity.rs +++ b/compiler/rustc_const_eval/src/interpret/validity.rs @@ -343,20 +343,16 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> { match tail.kind() { ty::Dynamic(data, _, ty::Dyn) => { let vtable = meta.unwrap_meta().to_pointer(self.ecx)?; - // Make sure it is a genuine vtable pointer. - let (_dyn_ty, dyn_trait) = try_validation!( - self.ecx.get_ptr_vtable(vtable), + // Make sure it is a genuine vtable pointer for the right trait. + try_validation!( + self.ecx.get_ptr_vtable_ty(vtable, Some(data)), self.path, Ub(DanglingIntPointer(..) | InvalidVTablePointer(..)) => - InvalidVTablePtr { value: format!("{vtable}") } + InvalidVTablePtr { value: format!("{vtable}") }, + Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => { + InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait } + }, ); - // Make sure it is for the right trait. - if dyn_trait != data.principal() { - throw_validation_failure!( - self.path, - InvalidMetaWrongTrait { expected_trait: data, vtable_trait: dyn_trait } - ); - } } ty::Slice(..) | ty::Str => { let _len = meta.unwrap_meta().to_target_usize(self.ecx)?; diff --git a/compiler/rustc_const_eval/src/interpret/visitor.rs b/compiler/rustc_const_eval/src/interpret/visitor.rs index b812e89854b20..71c057e549b41 100644 --- a/compiler/rustc_const_eval/src/interpret/visitor.rs +++ b/compiler/rustc_const_eval/src/interpret/visitor.rs @@ -95,7 +95,7 @@ pub trait ValueVisitor<'tcx, M: Machine<'tcx>>: Sized { // unsized values are never immediate, so we can assert_mem_place let op = v.to_op(self.ecx())?; let dest = op.assert_mem_place(); - let inner_mplace = self.ecx().unpack_dyn_trait(&dest, data)?.0; + let inner_mplace = self.ecx().unpack_dyn_trait(&dest, data)?; trace!("walk_value: dyn object layout: {:#?}", inner_mplace.layout); // recurse with the inner type return self.visit_field(v, 0, &inner_mplace.into()); @@ -104,7 +104,7 @@ pub trait ValueVisitor<'tcx, M: Machine<'tcx>>: Sized { // DynStar types. Very different from a dyn type (but strangely part of the // same variant in `TyKind`): These are pairs where the 2nd component is the // vtable, and the first component is the data (which must be ptr-sized). - let data = self.ecx().unpack_dyn_star(v, data)?.0; + let data = self.ecx().unpack_dyn_star(v, data)?; return self.visit_field(v, 0, &data); } // Slices do not need special handling here: they have `Array` field diff --git a/src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs b/src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs new file mode 100644 index 0000000000000..c46031de2d84f --- /dev/null +++ b/src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs @@ -0,0 +1,40 @@ +#![feature(ptr_metadata)] +// This test is the result of minimizing the `emplacable` crate to reproduce +// . + +use std::{ops::FnMut, ptr::Pointee}; + +pub type EmplacerFn<'a, T> = dyn for<'b> FnMut(::Metadata) + 'a; + +#[repr(transparent)] +pub struct Emplacer<'a, T>(EmplacerFn<'a, T>) +where + T: ?Sized; + +impl<'a, T> Emplacer<'a, T> +where + T: ?Sized, +{ + pub unsafe fn from_fn<'b>(emplacer_fn: &'b mut EmplacerFn<'a, T>) -> &'b mut Self { + // This used to trigger: + // constructing invalid value: wrong trait in wide pointer vtable: expected + // `std::ops::FnMut(<[std::boxed::Box] as std::ptr::Pointee>::Metadata)`, but encountered + // `std::ops::FnMut<(usize,)>`. + unsafe { &mut *((emplacer_fn as *mut EmplacerFn<'a, T>) as *mut Self) } + } +} + +pub fn box_new_with() +where + T: ?Sized, +{ + let emplacer_closure = &mut |_meta| { + unreachable!(); + }; + + unsafe { Emplacer::::from_fn(emplacer_closure) }; +} + +fn main() { + box_new_with::<[Box]>(); +}