Skip to content

Commit

Permalink
Rollup merge of #119699 - cjgillot:simplify-unreachable, r=oli-obk
Browse files Browse the repository at this point in the history
Merge dead bb pruning and unreachable bb deduplication.

Both routines share the same basic structure: iterate on all bbs to identify work, and then renumber bbs.

We can do both at once.
  • Loading branch information
GuillaumeGomez authored Jan 9, 2024
2 parents 72fdaf5 + 4071572 commit 9b90541
Show file tree
Hide file tree
Showing 19 changed files with 124 additions and 156 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,7 @@ impl<'tcx> BasicBlockData<'tcx> {
}

/// Does the block have no statements and an unreachable terminator?
#[inline]
pub fn is_empty_unreachable(&self) -> bool {
self.statements.is_empty() && matches!(self.terminator().kind, TerminatorKind::Unreachable)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/const_goto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<'tcx> MirPass<'tcx> for ConstGoto {
// if we applied optimizations, we potentially have some cfg to cleanup to
// make it easier for further passes
if should_simplify {
simplify_cfg(tcx, body);
simplify_cfg(body);
simplify_locals(body, tcx);
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/deduplicate_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl<'tcx> MirPass<'tcx> for DeduplicateBlocks {
if has_opts_to_apply {
let mut opt_applier = OptApplier { tcx, duplicates };
opt_applier.visit_body(body);
simplify_cfg(tcx, body);
simplify_cfg(body);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
// Since this optimization adds new basic blocks and invalidates others,
// clean up the cfg to make it nicer for other passes
if should_cleanup {
simplify_cfg(tcx, body);
simplify_cfg(body);
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_target::abi::FieldIdx;
use rustc_target::spec::abi::Abi;

use crate::cost_checker::CostChecker;
use crate::simplify::{remove_dead_blocks, CfgSimplifier};
use crate::simplify::simplify_cfg;
use crate::util;
use std::iter;
use std::ops::{Range, RangeFrom};
Expand Down Expand Up @@ -56,8 +56,7 @@ impl<'tcx> MirPass<'tcx> for Inline {
let _guard = span.enter();
if inline(tcx, body) {
debug!("running simplify cfg on {:?}", body.source);
CfgSimplifier::new(body).simplify();
remove_dead_blocks(body);
simplify_cfg(body);
deref_finder(tcx, body);
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
}

if should_cleanup {
simplify_cfg(tcx, body);
simplify_cfg(body);
}
}
}
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/remove_unneeded_drops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<'tcx> MirPass<'tcx> for RemoveUnneededDrops {
// if we applied optimizations, we potentially have some cfg to cleanup to
// make it easier for further passes
if should_simplify {
simplify_cfg(tcx, body);
simplify_cfg(body);
}
}
}
4 changes: 2 additions & 2 deletions compiler/rustc_mir_transform/src/separate_const_switch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ impl<'tcx> MirPass<'tcx> for SeparateConstSwitch {
sess.mir_opt_level() >= 2 && sess.opts.unstable_opts.unsound_mir_opts
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// If execution did something, applying a simplification layer
// helps later passes optimize the copy away.
if separate_const_switch(body) > 0 {
super::simplify::simplify_cfg(tcx, body);
super::simplify::simplify_cfg(body);
}
}
}
Expand Down
92 changes: 37 additions & 55 deletions compiler/rustc_mir_transform/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
//! return.
use rustc_data_structures::fx::FxIndexSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
Expand Down Expand Up @@ -62,9 +61,8 @@ impl SimplifyCfg {
}
}

pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
pub(crate) fn simplify_cfg(body: &mut Body<'_>) {
CfgSimplifier::new(body).simplify();
remove_duplicate_unreachable_blocks(tcx, body);
remove_dead_blocks(body);

// FIXME: Should probably be moved into some kind of pass manager
Expand All @@ -76,9 +74,9 @@ impl<'tcx> MirPass<'tcx> for SimplifyCfg {
self.name()
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source);
simplify_cfg(tcx, body);
simplify_cfg(body);
}
}

Expand Down Expand Up @@ -289,55 +287,25 @@ pub fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
}
}

pub fn remove_duplicate_unreachable_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
struct OptApplier<'tcx> {
tcx: TyCtxt<'tcx>,
duplicates: FxIndexSet<BasicBlock>,
}

impl<'tcx> MutVisitor<'tcx> for OptApplier<'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
for target in terminator.successors_mut() {
// We don't have to check whether `target` is a cleanup block, because have
// entirely excluded cleanup blocks in building the set of duplicates.
if self.duplicates.contains(target) {
*target = self.duplicates[0];
}
}

simplify_duplicate_switch_targets(terminator);

self.super_terminator(terminator, location);
}
}
pub(crate) fn remove_dead_blocks(body: &mut Body<'_>) {
let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
// terminator. Those blocks will be deleted by remove_dead_blocks, but we run just
// before then so we need to handle missing terminators.
// We also need to prevent confusing cleanup and non-cleanup blocks. In practice we
// don't emit empty unreachable cleanup blocks, so this simple check suffices.
bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup
};

let unreachable_blocks = body
let reachable = traversal::reachable_as_bitset(body);
let empty_unreachable_blocks = body
.basic_blocks
.iter_enumerated()
.filter(|(_, bb)| {
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
// terminator. Those blocks will be deleted by remove_dead_blocks, but we run just
// before then so we need to handle missing terminators.
// We also need to prevent confusing cleanup and non-cleanup blocks. In practice we
// don't emit empty unreachable cleanup blocks, so this simple check suffices.
bb.terminator.is_some() && bb.is_empty_unreachable() && !bb.is_cleanup
})
.map(|(block, _)| block)
.collect::<FxIndexSet<_>>();

if unreachable_blocks.len() > 1 {
OptApplier { tcx, duplicates: unreachable_blocks }.visit_body(body);
}
}
.filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb))
.count();

pub fn remove_dead_blocks(body: &mut Body<'_>) {
let reachable = traversal::reachable_as_bitset(body);
let num_blocks = body.basic_blocks.len();
if num_blocks == reachable.count() {
if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 {
return;
}

Expand All @@ -346,14 +314,28 @@ pub fn remove_dead_blocks(body: &mut Body<'_>) {
let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
let mut orig_index = 0;
let mut used_index = 0;
basic_blocks.raw.retain(|_| {
let keep = reachable.contains(BasicBlock::new(orig_index));
if keep {
replacements[orig_index] = BasicBlock::new(used_index);
used_index += 1;
let mut kept_unreachable = None;
basic_blocks.raw.retain(|bbdata| {
let orig_bb = BasicBlock::new(orig_index);
if !reachable.contains(orig_bb) {
orig_index += 1;
return false;
}

let used_bb = BasicBlock::new(used_index);
if should_deduplicate_unreachable(bbdata) {
let kept_unreachable = *kept_unreachable.get_or_insert(used_bb);
if kept_unreachable != used_bb {
replacements[orig_index] = kept_unreachable;
orig_index += 1;
return false;
}
}

replacements[orig_index] = used_bb;
used_index += 1;
orig_index += 1;
keep
true
});

for block in basic_blocks {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body@$DIR/async_await.rs:15:18: 18:2}>,

bb0: {
_39 = discriminant((*(_1.0: &mut {async fn body@$DIR/async_await.rs:15:18: 18:2})));
switchInt(move _39) -> [0: bb1, 1: bb29, 3: bb27, 4: bb28, otherwise: bb30];
switchInt(move _39) -> [0: bb1, 1: bb29, 3: bb27, 4: bb28, otherwise: bb9];
}

bb1: {
Expand Down Expand Up @@ -345,8 +345,4 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body@$DIR/async_await.rs:15:18: 18:2}>,
bb29: {
assert(const false, "`async fn` resumed after completion") -> [success: bb29, unwind unreachable];
}

bb30: {
unreachable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
StorageLive(_11);
StorageLive(_12);
_10 = discriminant(_4);
switchInt(move _10) -> [0: bb8, 1: bb6, otherwise: bb7];
switchInt(move _10) -> [0: bb7, 1: bb6, otherwise: bb2];
}

bb1: {
Expand Down Expand Up @@ -114,20 +114,16 @@
_3 = ControlFlow::<Result<Infallible, i32>, i32>::Break(move _13);
StorageDead(_13);
- goto -> bb5;
+ goto -> bb9;
+ goto -> bb8;
}

bb7: {
unreachable;
}

bb8: {
_11 = move ((_4 as Ok).0: i32);
_3 = ControlFlow::<Result<Infallible, i32>, i32>::Continue(move _11);
goto -> bb5;
+ }
+
+ bb9: {
+ bb8: {
+ StorageDead(_12);
+ StorageDead(_11);
+ StorageDead(_10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
StorageLive(_11);
StorageLive(_12);
_10 = discriminant(_4);
switchInt(move _10) -> [0: bb8, 1: bb6, otherwise: bb7];
switchInt(move _10) -> [0: bb7, 1: bb6, otherwise: bb2];
}

bb1: {
Expand Down Expand Up @@ -114,20 +114,16 @@
_3 = ControlFlow::<Result<Infallible, i32>, i32>::Break(move _13);
StorageDead(_13);
- goto -> bb5;
+ goto -> bb9;
+ goto -> bb8;
}

bb7: {
unreachable;
}

bb8: {
_11 = move ((_4 as Ok).0: i32);
_3 = ControlFlow::<Result<Infallible, i32>, i32>::Continue(move _11);
goto -> bb5;
+ }
+
+ bb9: {
+ bb8: {
+ StorageDead(_12);
+ StorageDead(_11);
+ StorageDead(_10);
Expand Down
8 changes: 3 additions & 5 deletions tests/mir-opt/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn identity(x: Result<i32, i32>) -> Result<i32, i32> {
// CHECK-LABEL: fn identity(
// CHECK: bb0: {
// CHECK: [[x:_.*]] = _1;
// CHECK: switchInt(move {{_.*}}) -> [0: bb8, 1: bb6, otherwise: bb7];
// CHECK: switchInt(move {{_.*}}) -> [0: bb7, 1: bb6, otherwise: bb2];
// CHECK: bb1: {
// CHECK: {{_.*}} = (([[controlflow:_.*]] as Continue).0: i32);
// CHECK: _0 = Result::<i32, i32>::Ok(
Expand All @@ -68,14 +68,12 @@ fn identity(x: Result<i32, i32>) -> Result<i32, i32> {
// CHECK: bb6: {
// CHECK: {{_.*}} = move (([[x]] as Err).0: i32);
// CHECK: [[controlflow]] = ControlFlow::<Result<Infallible, i32>, i32>::Break(
// CHECK: goto -> bb9;
// CHECK: goto -> bb8;
// CHECK: bb7: {
// CHECK: unreachable;
// CHECK: bb8: {
// CHECK: {{_.*}} = move (([[x]] as Ok).0: i32);
// CHECK: [[controlflow]] = ControlFlow::<Result<Infallible, i32>, i32>::Continue(
// CHECK: goto -> bb5;
// CHECK: bb9: {
// CHECK: bb8: {
// CHECK: goto -> bb3;
Ok(x?)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,16 @@
+ _2 = const Option::<Layout>::None;
StorageLive(_10);
- _10 = discriminant(_2);
- switchInt(move _10) -> [0: bb1, 1: bb3, otherwise: bb2];
- switchInt(move _10) -> [0: bb1, 1: bb2, otherwise: bb6];
+ _10 = const 0_isize;
+ switchInt(const 0_isize) -> [0: bb1, 1: bb3, otherwise: bb2];
+ switchInt(const 0_isize) -> [0: bb1, 1: bb2, otherwise: bb6];
}

bb1: {
_11 = core::panicking::panic(const "called `Option::unwrap()` on a `None` value") -> unwind unreachable;
}

bb2: {
unreachable;
}

bb3: {
- _1 = move ((_2 as Some).0: std::alloc::Layout);
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum32) }};
StorageDead(_10);
Expand All @@ -79,18 +75,18 @@
StorageLive(_5);
StorageLive(_6);
_9 = const _;
- _6 = std::alloc::Global::alloc_impl(_9, _1, const false) -> [return: bb4, unwind unreachable];
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum32) }}, const false) -> [return: bb4, unwind unreachable];
- _6 = std::alloc::Global::alloc_impl(_9, _1, const false) -> [return: bb3, unwind unreachable];
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum32) }}, const false) -> [return: bb3, unwind unreachable];
}

bb4: {
bb3: {
StorageLive(_12);
StorageLive(_15);
_12 = discriminant(_6);
switchInt(move _12) -> [0: bb6, 1: bb5, otherwise: bb2];
switchInt(move _12) -> [0: bb5, 1: bb4, otherwise: bb6];
}

bb5: {
bb4: {
_15 = const "called `Result::unwrap()` on an `Err` value";
StorageLive(_16);
StorageLive(_17);
Expand All @@ -100,7 +96,7 @@
_14 = result::unwrap_failed(move _15, move _16) -> unwind unreachable;
}

bb6: {
bb5: {
_5 = move ((_6 as Ok).0: std::ptr::NonNull<[u8]>);
StorageDead(_15);
StorageDead(_12);
Expand All @@ -115,6 +111,10 @@
StorageDead(_3);
return;
}

bb6: {
unreachable;
}
}
+
+ ALLOC0 (size: 8, align: 4) {
Expand Down
Loading

0 comments on commit 9b90541

Please sign in to comment.