Skip to content

Commit

Permalink
Rollup merge of #125041 - scottmcm:gvn-for-from-raw-parts, r=cjgillot
Browse files Browse the repository at this point in the history
Enable GVN for `AggregateKind::RawPtr`

Looks like I was worried for nothing; this seems like it's much easier than I was originally thinking it would be.
r? `@cjgillot`

This should be useful for `x[..4]`-like things, should those start inlining enough to expose the lengths.
  • Loading branch information
jieyouxu authored Jun 9, 2024
2 parents 7bb0ef4 + 021ccf6 commit f000b42
Show file tree
Hide file tree
Showing 14 changed files with 519 additions and 26 deletions.
99 changes: 87 additions & 12 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@
//! that contain `AllocId`s.
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind};
use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar};
use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemPlaceMeta, MemoryKind};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable, Scalar};
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_hir::def::DefKind;
Expand All @@ -99,7 +99,7 @@ use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::def_id::DefId;
use rustc_span::DUMMY_SP;
use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
use rustc_target::abi::{self, Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
use smallvec::SmallVec;
use std::borrow::Cow;

Expand Down Expand Up @@ -177,6 +177,12 @@ enum AggregateTy<'tcx> {
Array,
Tuple,
Def(DefId, ty::GenericArgsRef<'tcx>),
RawPtr {
/// Needed for cast propagation.
data_pointer_ty: Ty<'tcx>,
/// The data pointer can be anything thin, so doesn't determine the output.
output_pointer_ty: Ty<'tcx>,
},
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -385,11 +391,22 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
AggregateTy::Def(def_id, args) => {
self.tcx.type_of(def_id).instantiate(self.tcx, args)
}
AggregateTy::RawPtr { output_pointer_ty, .. } => output_pointer_ty,
};
let variant = if ty.is_enum() { Some(variant) } else { None };
let ty = self.ecx.layout_of(ty).ok()?;
if ty.is_zst() {
ImmTy::uninit(ty).into()
} else if matches!(kind, AggregateTy::RawPtr { .. }) {
// Pointers don't have fields, so don't `project_field` them.
let data = self.ecx.read_pointer(fields[0]).ok()?;
let meta = if fields[1].layout.is_zst() {
MemPlaceMeta::None
} else {
MemPlaceMeta::Meta(self.ecx.read_scalar(fields[1]).ok()?)
};
let ptr_imm = Immediate::new_pointer_with_meta(data, meta, &self.ecx);
ImmTy::from_immediate(ptr_imm, ty).into()
} else if matches!(ty.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
let dest = self.ecx.allocate(ty, MemoryKind::Stack).ok()?;
let variant_dest = if let Some(variant) = variant {
Expand Down Expand Up @@ -864,10 +881,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
rvalue: &mut Rvalue<'tcx>,
location: Location,
) -> Option<VnIndex> {
let Rvalue::Aggregate(box ref kind, ref mut fields) = *rvalue else { bug!() };
let Rvalue::Aggregate(box ref kind, ref mut field_ops) = *rvalue else { bug!() };

let tcx = self.tcx;
if fields.is_empty() {
if field_ops.is_empty() {
let is_zst = match *kind {
AggregateKind::Array(..)
| AggregateKind::Tuple
Expand All @@ -886,13 +903,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}

let (ty, variant_index) = match *kind {
let (mut ty, variant_index) = match *kind {
AggregateKind::Array(..) => {
assert!(!fields.is_empty());
assert!(!field_ops.is_empty());
(AggregateTy::Array, FIRST_VARIANT)
}
AggregateKind::Tuple => {
assert!(!fields.is_empty());
assert!(!field_ops.is_empty());
(AggregateTy::Tuple, FIRST_VARIANT)
}
AggregateKind::Closure(did, args)
Expand All @@ -903,15 +920,49 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
// Do not track unions.
AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
// FIXME: Do the extra work to GVN `from_raw_parts`
AggregateKind::RawPtr(..) => return None,
AggregateKind::RawPtr(pointee_ty, mtbl) => {
assert_eq!(field_ops.len(), 2);
let data_pointer_ty = field_ops[FieldIdx::ZERO].ty(self.local_decls, self.tcx);
let output_pointer_ty = Ty::new_ptr(self.tcx, pointee_ty, mtbl);
(AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty }, FIRST_VARIANT)
}
};

let fields: Option<Vec<_>> = fields
let fields: Option<Vec<_>> = field_ops
.iter_mut()
.map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
.collect();
let fields = fields?;
let mut fields = fields?;

if let AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty } = &mut ty {
let mut was_updated = false;

// Any thin pointer of matching mutability is fine as the data pointer.
while let Value::Cast {
kind: CastKind::PtrToPtr,
value: cast_value,
from: cast_from,
to: _,
} = self.get(fields[0])
&& let ty::RawPtr(from_pointee_ty, from_mtbl) = cast_from.kind()
&& let ty::RawPtr(_, output_mtbl) = output_pointer_ty.kind()
&& from_mtbl == output_mtbl
&& from_pointee_ty.is_sized(self.tcx, self.param_env)
{
fields[0] = *cast_value;
*data_pointer_ty = *cast_from;
was_updated = true;
}

if was_updated {
if let Some(const_) = self.try_as_constant(fields[0]) {
field_ops[FieldIdx::ZERO] = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(fields[0], location) {
field_ops[FieldIdx::ZERO] = Operand::Copy(Place::from(local));
self.reused_locals.insert(local);
}
}
}

if let AggregateTy::Array = ty
&& fields.len() > 4
Expand Down Expand Up @@ -943,6 +994,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
(UnOp::Not, Value::BinaryOp(BinOp::Ne, lhs, rhs)) => {
Value::BinaryOp(BinOp::Eq, *lhs, *rhs)
}
(UnOp::PtrMetadata, Value::Aggregate(AggregateTy::RawPtr { .. }, _, fields)) => {
return Some(fields[1]);
}
_ => return None,
};

Expand Down Expand Up @@ -1094,6 +1148,23 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
return self.new_opaque();
}

let mut was_updated = false;

// If that cast just casts away the metadata again,
if let PtrToPtr = kind
&& let Value::Aggregate(AggregateTy::RawPtr { data_pointer_ty, .. }, _, fields) =
self.get(value)
&& let ty::RawPtr(to_pointee, _) = to.kind()
&& to_pointee.is_sized(self.tcx, self.param_env)
{
from = *data_pointer_ty;
value = fields[0];
was_updated = true;
if *data_pointer_ty == to {
return Some(fields[0]);
}
}

if let PtrToPtr | PointerCoercion(MutToConstPointer) = kind
&& let Value::Cast { kind: inner_kind, value: inner_value, from: inner_from, to: _ } =
*self.get(value)
Expand All @@ -1102,9 +1173,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
from = inner_from;
value = inner_value;
*kind = PtrToPtr;
was_updated = true;
if inner_from == to {
return Some(inner_value);
}
}

if was_updated {
if let Some(const_) = self.try_as_constant(value) {
*operand = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(value, location) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
- // MIR for `casts_before_aggregate_raw_ptr` before GVN
+ // MIR for `casts_before_aggregate_raw_ptr` after GVN

fn casts_before_aggregate_raw_ptr(_1: *const u32) -> *const [u8] {
debug x => _1;
let mut _0: *const [u8];
let _2: *const [u8; 4];
let mut _3: *const u32;
let mut _5: *const [u8; 4];
let mut _7: *const u8;
let mut _8: *const ();
scope 1 {
debug x => _2;
let _4: *const u8;
scope 2 {
debug x => _4;
let _6: *const ();
scope 3 {
debug x => _6;
}
}
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = move _3 as *const [u8; 4] (PtrToPtr);
+ _2 = _1 as *const [u8; 4] (PtrToPtr);
StorageDead(_3);
- StorageLive(_4);
+ nop;
StorageLive(_5);
_5 = _2;
- _4 = move _5 as *const u8 (PtrToPtr);
+ _4 = _1 as *const u8 (PtrToPtr);
StorageDead(_5);
- StorageLive(_6);
+ nop;
StorageLive(_7);
_7 = _4;
- _6 = move _7 as *const () (PtrToPtr);
+ _6 = _1 as *const () (PtrToPtr);
StorageDead(_7);
StorageLive(_8);
_8 = _6;
- _0 = *const [u8] from (move _8, const 4_usize);
+ _0 = *const [u8] from (_1, const 4_usize);
StorageDead(_8);
- StorageDead(_6);
- StorageDead(_4);
- StorageDead(_2);
+ nop;
+ nop;
+ nop;
return;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
- // MIR for `casts_before_aggregate_raw_ptr` before GVN
+ // MIR for `casts_before_aggregate_raw_ptr` after GVN

fn casts_before_aggregate_raw_ptr(_1: *const u32) -> *const [u8] {
debug x => _1;
let mut _0: *const [u8];
let _2: *const [u8; 4];
let mut _3: *const u32;
let mut _5: *const [u8; 4];
let mut _7: *const u8;
let mut _8: *const ();
scope 1 {
debug x => _2;
let _4: *const u8;
scope 2 {
debug x => _4;
let _6: *const ();
scope 3 {
debug x => _6;
}
}
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = move _3 as *const [u8; 4] (PtrToPtr);
+ _2 = _1 as *const [u8; 4] (PtrToPtr);
StorageDead(_3);
- StorageLive(_4);
+ nop;
StorageLive(_5);
_5 = _2;
- _4 = move _5 as *const u8 (PtrToPtr);
+ _4 = _1 as *const u8 (PtrToPtr);
StorageDead(_5);
- StorageLive(_6);
+ nop;
StorageLive(_7);
_7 = _4;
- _6 = move _7 as *const () (PtrToPtr);
+ _6 = _1 as *const () (PtrToPtr);
StorageDead(_7);
StorageLive(_8);
_8 = _6;
- _0 = *const [u8] from (move _8, const 4_usize);
+ _0 = *const [u8] from (_1, const 4_usize);
StorageDead(_8);
- StorageDead(_6);
- StorageDead(_4);
- StorageDead(_2);
+ nop;
+ nop;
+ nop;
return;
}
}

32 changes: 32 additions & 0 deletions tests/mir-opt/gvn.meta_of_ref_to_slice.GVN.panic-abort.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
- // MIR for `meta_of_ref_to_slice` before GVN
+ // MIR for `meta_of_ref_to_slice` after GVN

fn meta_of_ref_to_slice(_1: *const i32) -> usize {
debug x => _1;
let mut _0: usize;
let _2: *const [i32];
let mut _3: *const i32;
let mut _4: *const [i32];
scope 1 {
debug ptr => _2;
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = *const [i32] from (move _3, const 1_usize);
+ _2 = *const [i32] from (_1, const 1_usize);
StorageDead(_3);
StorageLive(_4);
_4 = _2;
- _0 = PtrMetadata(move _4);
+ _0 = const 1_usize;
StorageDead(_4);
- StorageDead(_2);
+ nop;
return;
}
}

32 changes: 32 additions & 0 deletions tests/mir-opt/gvn.meta_of_ref_to_slice.GVN.panic-unwind.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
- // MIR for `meta_of_ref_to_slice` before GVN
+ // MIR for `meta_of_ref_to_slice` after GVN

fn meta_of_ref_to_slice(_1: *const i32) -> usize {
debug x => _1;
let mut _0: usize;
let _2: *const [i32];
let mut _3: *const i32;
let mut _4: *const [i32];
scope 1 {
debug ptr => _2;
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = *const [i32] from (move _3, const 1_usize);
+ _2 = *const [i32] from (_1, const 1_usize);
StorageDead(_3);
StorageLive(_4);
_4 = _2;
- _0 = PtrMetadata(move _4);
+ _0 = const 1_usize;
StorageDead(_4);
- StorageDead(_2);
+ nop;
return;
}
}

Loading

0 comments on commit f000b42

Please sign in to comment.