Skip to content

Commit

Permalink
Rollup merge of #130727 - compiler-errors:objects, r=RalfJung
Browse files Browse the repository at this point in the history
Check vtable projections for validity in miri

Currently, miri does not catch when we transmute `dyn Trait<Assoc = A>` to `dyn Trait<Assoc = B>`. This PR implements such a check, and fixes rust-lang/miri#3905.

To do this, we modify `GlobalAlloc::VTable` to contain the *whole* list of `PolyExistentialPredicate`, and then modify `check_vtable_for_type` to validate the `PolyExistentialProjection`s of the vtable, along with the principal trait that was already being validated.

cc ``@RalfJung``
r? ``@lcnr`` or types

I also tweaked the diagnostics a bit.

---

**Open question:** We don't validate the auto traits. You can transmute `dyn Foo` into `dyn Foo + Send`. Should we check that? We currently have a test that *exercises* this as not being UB:

https://github.com/rust-lang/rust/blob/6c6d210089e4589afee37271862b9f88ba1d7755/src/tools/miri/tests/pass/dyn-upcast.rs#L14-L20

I'm not actually sure if we ever decided that's actually UB or not 🤔

We could perhaps still check that the underlying type of the object (i.e. the concrete type that was unsized) implements the auto traits, to catch UB like:

```rust
fn main() {
    let x: &dyn Trait = &std::ptr::null_mut::<()>();
    let _: &(dyn Trait + Send) = std::mem::transmute(x);
    //~^ this vtable is not allocated for a type that is `Send`!
}
```
  • Loading branch information
compiler-errors authored Sep 24, 2024
2 parents c0f1a69 + 702a644 commit ec1ccff
Show file tree
Hide file tree
Showing 22 changed files with 143 additions and 87 deletions.
8 changes: 4 additions & 4 deletions compiler/rustc_codegen_cranelift/src/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,13 @@ pub(crate) fn codegen_const_value<'tcx>(
fx.module.declare_func_in_func(func_id, &mut fx.bcx.func);
fx.bcx.ins().func_addr(fx.pointer_type, local_func_id)
}
GlobalAlloc::VTable(ty, trait_ref) => {
GlobalAlloc::VTable(ty, dyn_ty) => {
let data_id = data_id_for_vtable(
fx.tcx,
&mut fx.constants_cx,
fx.module,
ty,
trait_ref,
dyn_ty.principal(),
);
let local_data_id =
fx.module.declare_data_in_func(data_id, &mut fx.bcx.func);
Expand Down Expand Up @@ -456,8 +456,8 @@ fn define_all_allocs(tcx: TyCtxt<'_>, module: &mut dyn Module, cx: &mut Constant
GlobalAlloc::Memory(target_alloc) => {
data_id_for_alloc_id(cx, module, alloc_id, target_alloc.inner().mutability)
}
GlobalAlloc::VTable(ty, trait_ref) => {
data_id_for_vtable(tcx, cx, module, ty, trait_ref)
GlobalAlloc::VTable(ty, dyn_ty) => {
data_id_for_vtable(tcx, cx, module, ty, dyn_ty.principal())
}
GlobalAlloc::Static(def_id) => {
if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::THREAD_LOCAL)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_gcc/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ impl<'gcc, 'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
value
}
GlobalAlloc::Function { instance, .. } => self.get_fn_addr(instance),
GlobalAlloc::VTable(ty, trait_ref) => {
GlobalAlloc::VTable(ty, dyn_ty) => {
let alloc = self
.tcx
.global_alloc(self.tcx.vtable_allocation((ty, trait_ref)))
.global_alloc(self.tcx.vtable_allocation((ty, dyn_ty.principal())))
.unwrap_memory();
let init = const_alloc_to_gcc(self, alloc);
self.static_addr_of(init, alloc.inner().align, None)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_llvm/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,10 @@ impl<'ll, 'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
self.get_fn_addr(instance.polymorphize(self.tcx)),
self.data_layout().instruction_address_space,
),
GlobalAlloc::VTable(ty, trait_ref) => {
GlobalAlloc::VTable(ty, dyn_ty) => {
let alloc = self
.tcx
.global_alloc(self.tcx.vtable_allocation((ty, trait_ref)))
.global_alloc(self.tcx.vtable_allocation((ty, dyn_ty.principal())))
.unwrap_memory();
let init = const_alloc_to_llvm(self, alloc, /*static*/ false);
let value = self.static_addr_of(init, alloc.inner().align, None);
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ const_eval_invalid_vtable_pointer =
using {$pointer} as vtable pointer but it does not point to a vtable
const_eval_invalid_vtable_trait =
using vtable for trait `{$vtable_trait}` but trait `{$expected_trait}` was expected
using vtable for `{$vtable_dyn_type}` but `{$expected_dyn_type}` was expected
const_eval_lazy_lock =
consider wrapping this expression in `std::sync::LazyLock::new(|| ...)`
Expand Down Expand Up @@ -459,7 +459,7 @@ const_eval_validation_invalid_fn_ptr = {$front_matter}: encountered {$value}, bu
const_eval_validation_invalid_ref_meta = {$front_matter}: encountered invalid reference metadata: total size is bigger than largest supported object
const_eval_validation_invalid_ref_slice_meta = {$front_matter}: encountered invalid reference metadata: slice is bigger than largest supported object
const_eval_validation_invalid_vtable_ptr = {$front_matter}: encountered {$value}, but expected a vtable pointer
const_eval_validation_invalid_vtable_trait = {$front_matter}: wrong trait in wide pointer vtable: expected `{$ref_trait}`, but encountered `{$vtable_trait}`
const_eval_validation_invalid_vtable_trait = {$front_matter}: wrong trait in wide pointer vtable: expected `{$expected_dyn_type}`, but encountered `{$vtable_dyn_type}`
const_eval_validation_mutable_ref_to_immutable = {$front_matter}: encountered mutable reference or box pointing to read-only memory
const_eval_validation_never_val = {$front_matter}: encountered a value of the never type `!`
const_eval_validation_null_box = {$front_matter}: encountered a null box
Expand Down
18 changes: 6 additions & 12 deletions compiler/rustc_const_eval/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,12 +522,9 @@ impl<'a> ReportErrorExt for UndefinedBehaviorInfo<'a> {
UnterminatedCString(ptr) | InvalidFunctionPointer(ptr) | InvalidVTablePointer(ptr) => {
diag.arg("pointer", ptr);
}
InvalidVTableTrait { expected_trait, vtable_trait } => {
diag.arg("expected_trait", expected_trait.to_string());
diag.arg(
"vtable_trait",
vtable_trait.map(|t| t.to_string()).unwrap_or_else(|| format!("<trivial>")),
);
InvalidVTableTrait { expected_dyn_type, vtable_dyn_type } => {
diag.arg("expected_dyn_type", expected_dyn_type.to_string());
diag.arg("vtable_dyn_type", vtable_dyn_type.to_string());
}
PointerUseAfterFree(alloc_id, msg) => {
diag.arg("alloc_id", alloc_id)
Expand Down Expand Up @@ -777,12 +774,9 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
DanglingPtrNoProvenance { pointer, .. } => {
err.arg("pointer", pointer);
}
InvalidMetaWrongTrait { expected_trait: ref_trait, vtable_trait } => {
err.arg("ref_trait", ref_trait.to_string());
err.arg(
"vtable_trait",
vtable_trait.map(|t| t.to_string()).unwrap_or_else(|| format!("<trivial>")),
);
InvalidMetaWrongTrait { vtable_dyn_type, expected_dyn_type } => {
err.arg("vtable_dyn_type", vtable_dyn_type.to_string());
err.arg("expected_dyn_type", expected_dyn_type.to_string());
}
NullPtr { .. }
| ConstRefToMutable
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
CastKind::DynStar => {
if let ty::Dynamic(data, _, ty::DynStar) = cast_ty.kind() {
// Initial cast from sized to dyn trait
let vtable = self.get_vtable_ptr(src.layout.ty, data.principal())?;
let vtable = self.get_vtable_ptr(src.layout.ty, data)?;
let vtable = Scalar::from_maybe_pointer(vtable, self);
let data = self.read_immediate(src)?.to_scalar();
let _assert_pointer_like = data.to_pointer(self)?;
Expand Down Expand Up @@ -446,12 +446,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}

// Get the destination trait vtable and return that.
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
let new_vptr = self.get_vtable_ptr(ty, data_b)?;
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
}
(_, &ty::Dynamic(data, _, ty::Dyn)) => {
// Initial cast from sized to dyn trait
let vtable = self.get_vtable_ptr(src_pointee_ty, data.principal())?;
let vtable = self.get_vtable_ptr(src_pointee_ty, data)?;
let ptr = self.read_pointer(src)?;
let val = Immediate::new_dyn_trait(ptr, vtable, &*self.tcx);
self.write_immediate(val, dest)
Expand Down
14 changes: 6 additions & 8 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -943,12 +943,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
if offset.bytes() != 0 {
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
}
let Some(GlobalAlloc::VTable(ty, vtable_trait)) = self.tcx.try_get_global_alloc(alloc_id)
let Some(GlobalAlloc::VTable(ty, vtable_dyn_type)) =
self.tcx.try_get_global_alloc(alloc_id)
else {
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
};
if let Some(expected_trait) = expected_trait {
self.check_vtable_for_type(vtable_trait, expected_trait)?;
if let Some(expected_dyn_type) = expected_trait {
self.check_vtable_for_type(vtable_dyn_type, expected_dyn_type)?;
}
Ok(ty)
}
Expand Down Expand Up @@ -1113,11 +1114,8 @@ impl<'a, 'tcx, M: Machine<'tcx>> std::fmt::Debug for DumpAllocs<'a, 'tcx, M> {
Some(GlobalAlloc::Function { instance, .. }) => {
write!(fmt, " (fn: {instance})")?;
}
Some(GlobalAlloc::VTable(ty, Some(trait_ref))) => {
write!(fmt, " (vtable: impl {trait_ref} for {ty})")?;
}
Some(GlobalAlloc::VTable(ty, None)) => {
write!(fmt, " (vtable: impl <auto trait> for {ty})")?;
Some(GlobalAlloc::VTable(ty, dyn_ty)) => {
write!(fmt, " (vtable: impl {dyn_ty} for {ty})")?;
}
Some(GlobalAlloc::Static(did)) => {
write!(fmt, " (static: {})", self.ecx.tcx.def_path_str(did))?;
Expand Down
67 changes: 47 additions & 20 deletions compiler/rustc_const_eval/src/interpret/traits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_middle::mir::interpret::{InterpResult, Pointer};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
use rustc_middle::ty::{self, ExistentialPredicateStableCmpExt, Ty, TyCtxt, VtblEntry};
use rustc_target::abi::{Align, Size};
use tracing::trace;

Expand All @@ -11,26 +11,25 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
/// Creates a dynamic vtable for the given type and vtable origin. This is used only for
/// objects.
///
/// The `trait_ref` encodes the erased self type. Hence, if we are making an object `Foo<Trait>`
/// from a value of type `Foo<T>`, then `trait_ref` would map `T: Trait`. `None` here means that
/// this is an auto trait without any methods, so we only need the basic vtable (drop, size,
/// align).
/// The `dyn_ty` encodes the erased self type. Hence, if we are making an object
/// `Foo<dyn Trait<Assoc = A> + Send>` from a value of type `Foo<T>`, then `dyn_ty`
/// would be `Trait<Assoc = A> + Send`. If this list doesn't have a principal trait ref,
/// we only need the basic vtable prefix (drop, size, align).
pub fn get_vtable_ptr(
&self,
ty: Ty<'tcx>,
poly_trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
dyn_ty: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, Pointer<Option<M::Provenance>>> {
trace!("get_vtable(trait_ref={:?})", poly_trait_ref);
trace!("get_vtable(ty={ty:?}, dyn_ty={dyn_ty:?})");

let (ty, poly_trait_ref) = self.tcx.erase_regions((ty, poly_trait_ref));
let (ty, dyn_ty) = self.tcx.erase_regions((ty, dyn_ty));

// All vtables must be monomorphic, bail out otherwise.
ensure_monomorphic_enough(*self.tcx, ty)?;
ensure_monomorphic_enough(*self.tcx, poly_trait_ref)?;
ensure_monomorphic_enough(*self.tcx, dyn_ty)?;

let salt = M::get_global_alloc_salt(self, None);
let vtable_symbolic_allocation =
self.tcx.reserve_and_set_vtable_alloc(ty, poly_trait_ref, salt);
let vtable_symbolic_allocation = self.tcx.reserve_and_set_vtable_alloc(ty, dyn_ty, salt);
let vtable_ptr = self.global_root_pointer(Pointer::from(vtable_symbolic_allocation))?;
Ok(vtable_ptr.into())
}
Expand Down Expand Up @@ -64,17 +63,45 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
/// expected trait type.
pub(super) fn check_vtable_for_type(
&self,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
vtable_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
expected_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx> {
let eq = match (expected_trait.principal(), vtable_trait) {
(Some(a), Some(b)) => self.eq_in_param_env(a, b),
(None, None) => true,
_ => false,
};
if !eq {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
// We check validity by comparing the lists of predicates for equality. We *could* instead
// check that the dynamic type to which the vtable belongs satisfies all the expected
// predicates, but that would likely be a lot slower and seems unnecessarily permissive.

// FIXME: we are skipping auto traits for now, but might revisit this in the future.
let mut sorted_vtable: Vec<_> = vtable_dyn_type.without_auto_traits().collect();
let mut sorted_expected: Vec<_> = expected_dyn_type.without_auto_traits().collect();
// `skip_binder` here is okay because `stable_cmp` doesn't look at binders
sorted_vtable.sort_by(|a, b| a.skip_binder().stable_cmp(*self.tcx, &b.skip_binder()));
sorted_vtable.dedup();
sorted_expected.sort_by(|a, b| a.skip_binder().stable_cmp(*self.tcx, &b.skip_binder()));
sorted_expected.dedup();

if sorted_vtable.len() != sorted_expected.len() {
throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
}

for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
(
ty::ExistentialPredicate::Trait(a_data),
ty::ExistentialPredicate::Trait(b_data),
) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),

(
ty::ExistentialPredicate::Projection(a_data),
ty::ExistentialPredicate::Projection(b_data),
) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),

_ => false,
};
if !is_eq {
throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
}
}

Ok(())
}

Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_const_eval/src/interpret/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,8 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
self.path,
Ub(DanglingIntPointer{ .. } | InvalidVTablePointer(..)) =>
InvalidVTablePtr { value: format!("{vtable}") },
Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
Ub(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type }) => {
InvalidMetaWrongTrait { vtable_dyn_type, expected_dyn_type }
},
);
}
Expand Down Expand Up @@ -1281,8 +1281,8 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
self.path,
// It's not great to catch errors here, since we can't give a very good path,
// but it's better than ICEing.
Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
Ub(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type }) => {
InvalidMetaWrongTrait { vtable_dyn_type, expected_dyn_type: *expected_dyn_type }
},
);
}
Expand Down
12 changes: 8 additions & 4 deletions compiler/rustc_middle/src/mir/interpret/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,10 @@ pub enum UndefinedBehaviorInfo<'tcx> {
InvalidVTablePointer(Pointer<AllocId>),
/// Using a vtable for the wrong trait.
InvalidVTableTrait {
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
/// The vtable that was actually referenced by the wide pointer metadata.
vtable_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
/// The vtable that was expected at the point in MIR that it was accessed.
expected_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
},
/// Using a string that is not valid UTF-8,
InvalidStr(std::str::Utf8Error),
Expand Down Expand Up @@ -479,8 +481,10 @@ pub enum ValidationErrorKind<'tcx> {
value: String,
},
InvalidMetaWrongTrait {
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
/// The vtable that was actually referenced by the wide pointer metadata.
vtable_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
/// The vtable that was expected at the point in MIR that it was accessed.
expected_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
},
InvalidMetaSliceTooLarge {
ptr_kind: PointerKind,
Expand Down
16 changes: 9 additions & 7 deletions compiler/rustc_middle/src/mir/interpret/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,8 @@ impl<'s> AllocDecodingSession<'s> {
}
AllocDiscriminant::VTable => {
trace!("creating vtable alloc ID");
let ty = <Ty<'_> as Decodable<D>>::decode(decoder);
let poly_trait_ref =
<Option<ty::PolyExistentialTraitRef<'_>> as Decodable<D>>::decode(decoder);
let ty = Decodable::decode(decoder);
let poly_trait_ref = Decodable::decode(decoder);
trace!("decoded vtable alloc instance: {ty:?}, {poly_trait_ref:?}");
decoder.interner().reserve_and_set_vtable_alloc(ty, poly_trait_ref, CTFE_ALLOC_SALT)
}
Expand All @@ -259,7 +258,10 @@ pub enum GlobalAlloc<'tcx> {
/// The alloc ID is used as a function pointer.
Function { instance: Instance<'tcx> },
/// This alloc ID points to a symbolic (not-reified) vtable.
VTable(Ty<'tcx>, Option<ty::PolyExistentialTraitRef<'tcx>>),
/// We remember the full dyn type, not just the principal trait, so that
/// const-eval and Miri can detect UB due to invalid transmutes of
/// `dyn Trait` types.
VTable(Ty<'tcx>, &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>),
/// The alloc ID points to a "lazy" static variable that did not get computed (yet).
/// This is also used to break the cycle in recursive statics.
Static(DefId),
Expand Down Expand Up @@ -293,7 +295,7 @@ impl<'tcx> GlobalAlloc<'tcx> {
#[inline]
pub fn unwrap_vtable(&self) -> (Ty<'tcx>, Option<ty::PolyExistentialTraitRef<'tcx>>) {
match *self {
GlobalAlloc::VTable(ty, poly_trait_ref) => (ty, poly_trait_ref),
GlobalAlloc::VTable(ty, dyn_ty) => (ty, dyn_ty.principal()),
_ => bug!("expected vtable, got {:?}", self),
}
}
Expand Down Expand Up @@ -398,10 +400,10 @@ impl<'tcx> TyCtxt<'tcx> {
pub fn reserve_and_set_vtable_alloc(
self,
ty: Ty<'tcx>,
poly_trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
dyn_ty: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
salt: usize,
) -> AllocId {
self.reserve_and_set_dedup(GlobalAlloc::VTable(ty, poly_trait_ref), salt)
self.reserve_and_set_dedup(GlobalAlloc::VTable(ty, dyn_ty), salt)
}

/// Interns the `Allocation` and return a new `AllocId`, even if there's already an identical
Expand Down
Loading

0 comments on commit ec1ccff

Please sign in to comment.