Skip to content

Commit

Permalink
Rollup merge of rust-lang#115796 - cjgillot:const-prop-rvalue, r=oli-obk
Browse files Browse the repository at this point in the history
Generate aggregate constants in DataflowConstProp.
  • Loading branch information
matthiaskrgr authored Oct 23, 2023
2 parents e2068cd + 8c1b039 commit 7d0a7de
Show file tree
Hide file tree
Showing 23 changed files with 785 additions and 158 deletions.
190 changes: 172 additions & 18 deletions compiler/rustc_mir_transform/src/dataflow_const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
//!
//! Currently, this pass only propagates scalar values.
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_mir_dataflow::value_analysis::{
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
};
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
use rustc_span::def_id::DefId;
use rustc_span::DUMMY_SP;
use rustc_target::abi::{FieldIdx, VariantIdx};
use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};

use crate::const_prop::throw_machine_stop_str;
use crate::MirPass;

// These constants are somewhat random guesses and have not been optimized.
Expand Down Expand Up @@ -553,16 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {

fn try_make_constant(
&self,
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
place: Place<'tcx>,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> Option<Const<'tcx>> {
let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else {
return None;
};
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
Some(Const::Val(ConstValue::Scalar(value.into()), ty))
let layout = ecx.layout_of(ty).ok()?;

if layout.is_zst() {
return Some(Const::zero_sized(ty));
}

if layout.is_unsized() {
return None;
}

let place = map.find(place.as_ref())?;
if layout.abi.is_scalar()
&& let Some(value) = propagatable_scalar(place, state, map)
{
return Some(Const::Val(ConstValue::Scalar(value), ty));
}

if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
let alloc_id = ecx
.intern_with_temp_alloc(layout, |ecx, dest| {
try_write_constant(ecx, dest, place, ty, state, map)
})
.ok()?;
return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
}

None
}
}

fn propagatable_scalar(
place: PlaceIndex,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> Option<Scalar> {
if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
// Do not attempt to propagate pointers, as we may fail to preserve their identity.
Some(value)
} else {
None
}
}

#[instrument(level = "trace", skip(ecx, state, map))]
fn try_write_constant<'tcx>(
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
dest: &PlaceTy<'tcx>,
place: PlaceIndex,
ty: Ty<'tcx>,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> InterpResult<'tcx> {
let layout = ecx.layout_of(ty)?;

// Fast path for ZSTs.
if layout.is_zst() {
return Ok(());
}

// Fast path for scalars.
if layout.abi.is_scalar()
&& let Some(value) = propagatable_scalar(place, state, map)
{
return ecx.write_immediate(Immediate::Scalar(value), dest);
}

match ty.kind() {
// ZSTs. Nothing to do.
ty::FnDef(..) => {}

// Those are scalars, must be handled above.
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),

ty::Tuple(elem_tys) => {
for (i, elem) in elem_tys.iter().enumerate() {
let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
throw_machine_stop_str!("missing field in tuple")
};
let field_dest = ecx.project_field(dest, i)?;
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
}
}

ty::Adt(def, args) => {
if def.is_union() {
throw_machine_stop_str!("cannot propagate unions")
}

let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
throw_machine_stop_str!("missing discriminant for enum")
};
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
throw_machine_stop_str!("discriminant with provenance")
};
let discr_bits = discr.assert_bits(discr.size());
let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
throw_machine_stop_str!("illegal discriminant for enum")
};
let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
throw_machine_stop_str!("missing variant for enum")
};
let variant_dest = ecx.project_downcast(dest, variant)?;
(variant, def.variant(variant), variant_place, variant_dest)
} else {
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
};

for (i, field) in variant_def.fields.iter_enumerated() {
let ty = field.ty(*ecx.tcx, args);
let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
throw_machine_stop_str!("missing field in ADT")
};
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
}
ecx.write_discriminant(variant_idx, dest)?;
}

// Unsupported for now.
ty::Array(_, _)

// Do not attempt to support indirection in constants.
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)

| ty::Never
| ty::Foreign(..)
| ty::Alias(..)
| ty::Param(_)
| ty::Bound(..)
| ty::Placeholder(..)
| ty::Closure(..)
| ty::Coroutine(..)
| ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),

ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
}

Ok(())
}

impl<'mir, 'tcx>
Expand All @@ -580,8 +716,13 @@ impl<'mir, 'tcx>
) {
match &statement.kind {
StatementKind::Assign(box (_, rvalue)) => {
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
.visit_rvalue(rvalue, location);
OperandCollector {
state,
visitor: self,
ecx: &mut results.analysis.0.ecx,
map: &results.analysis.0.map,
}
.visit_rvalue(rvalue, location);
}
_ => (),
}
Expand All @@ -599,7 +740,12 @@ impl<'mir, 'tcx>
// Don't overwrite the assignment if it already uses a constant (to keep the span).
}
StatementKind::Assign(box (place, _)) => {
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
if let Some(value) = self.try_make_constant(
&mut results.analysis.0.ecx,
place,
state,
&results.analysis.0.map,
) {
self.patch.assignments.insert(location, value);
}
}
Expand All @@ -614,8 +760,13 @@ impl<'mir, 'tcx>
terminator: &'mir Terminator<'tcx>,
location: Location,
) {
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
.visit_terminator(terminator, location);
OperandCollector {
state,
visitor: self,
ecx: &mut results.analysis.0.ecx,
map: &results.analysis.0.map,
}
.visit_terminator(terminator, location);
}
}

Expand Down Expand Up @@ -670,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
state: &'a State<FlatSet<Scalar>>,
visitor: &'a mut Collector<'tcx, 'locals>,
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
map: &'map Map,
}

Expand All @@ -682,15 +834,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
location: Location,
) {
if let PlaceElem::Index(local) = elem
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
{
self.visitor.patch.before_effect.insert((location, local.into()), value);
}
}

fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
if let Some(place) = operand.place() {
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
if let Some(value) =
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
{
self.visitor.patch.before_effect.insert((location, place), value);
} else if !place.projection.is_empty() {
// Try to propagate into `Index` projections.
Expand All @@ -713,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
}

fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
unimplemented!()
false
}

fn before_access_global(
Expand All @@ -725,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
is_write: bool,
) -> InterpResult<'tcx> {
if is_write {
crate::const_prop::throw_machine_stop_str!("can't write to global");
throw_machine_stop_str!("can't write to global");
}

// If the static allocation is mutable, then we can't const prop it as its content
// might be different at runtime.
if alloc.inner().mutability.is_mut() {
crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
throw_machine_stop_str!("can't access mutable globals in ConstProp");
}

Ok(())
Expand Down Expand Up @@ -781,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
_left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
_right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
throw_machine_stop_str!("can't do pointer arithmetic");
}

fn expose_ptr(
Expand Down
9 changes: 7 additions & 2 deletions tests/mir-opt/const_debuginfo.main.ConstDebugInfo.diff
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
+ debug ((f: (bool, bool, u32)).2: u32) => const 123_u32;
let _10: std::option::Option<u16>;
scope 7 {
debug o => _10;
- debug o => _10;
+ debug o => const Option::<u16>::Some(99_u16);
let _17: u32;
let _18: u32;
scope 8 {
Expand Down Expand Up @@ -81,7 +82,7 @@
_15 = const false;
_16 = const 123_u32;
StorageLive(_10);
_10 = Option::<u16>::Some(const 99_u16);
_10 = const Option::<u16>::Some(99_u16);
_17 = const 32_u32;
_18 = const 32_u32;
StorageLive(_11);
Expand All @@ -97,3 +98,7 @@
}
}

ALLOC0 (size: 4, align: 2) {
01 00 63 00 │ ..c.
}

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
- _6 = CheckedAdd(_4, _5);
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind unreachable];
+ _5 = const 2_i32;
+ _6 = CheckedAdd(const 1_i32, const 2_i32);
+ _6 = const (3_i32, false);
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind unreachable];
}

Expand All @@ -60,7 +60,7 @@
- _10 = CheckedAdd(_9, const 1_i32);
- assert(!move (_10.1: bool), "attempt to compute `{} + {}`, which would overflow", move _9, const 1_i32) -> [success: bb2, unwind unreachable];
+ _9 = const i32::MAX;
+ _10 = CheckedAdd(const i32::MAX, const 1_i32);
+ _10 = const (i32::MIN, true);
+ assert(!const true, "attempt to compute `{} + {}`, which would overflow", const i32::MAX, const 1_i32) -> [success: bb2, unwind unreachable];
}

Expand All @@ -76,5 +76,13 @@
StorageDead(_1);
return;
}
+ }
+
+ ALLOC0 (size: 8, align: 4) {
+ 00 00 00 80 01 __ __ __ │ .....░░░
+ }
+
+ ALLOC1 (size: 8, align: 4) {
+ 03 00 00 00 00 __ __ __ │ .....░░░
}

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
- _6 = CheckedAdd(_4, _5);
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind continue];
+ _5 = const 2_i32;
+ _6 = CheckedAdd(const 1_i32, const 2_i32);
+ _6 = const (3_i32, false);
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind continue];
}

Expand All @@ -60,7 +60,7 @@
- _10 = CheckedAdd(_9, const 1_i32);
- assert(!move (_10.1: bool), "attempt to compute `{} + {}`, which would overflow", move _9, const 1_i32) -> [success: bb2, unwind continue];
+ _9 = const i32::MAX;
+ _10 = CheckedAdd(const i32::MAX, const 1_i32);
+ _10 = const (i32::MIN, true);
+ assert(!const true, "attempt to compute `{} + {}`, which would overflow", const i32::MAX, const 1_i32) -> [success: bb2, unwind continue];
}

Expand All @@ -76,5 +76,13 @@
StorageDead(_1);
return;
}
+ }
+
+ ALLOC0 (size: 8, align: 4) {
+ 00 00 00 80 01 __ __ __ │ .....░░░
+ }
+
+ ALLOC1 (size: 8, align: 4) {
+ 03 00 00 00 00 __ __ __ │ .....░░░
}

2 changes: 1 addition & 1 deletion tests/mir-opt/dataflow-const-prop/checked.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// skip-filecheck
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY
// unit-test: DataflowConstProp
// compile-flags: -Coverflow-checks=on
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY

// EMIT_MIR checked.main.DataflowConstProp.diff
#[allow(arithmetic_overflow)]
Expand Down
Loading

0 comments on commit 7d0a7de

Please sign in to comment.