Skip to content

Commit

Permalink
Update MIR with MirPatch in UninhabitedEnumBranching
Browse files Browse the repository at this point in the history
  • Loading branch information
DianQK committed Mar 8, 2024
1 parent 3d7f8b4 commit b5bd98d
Show file tree
Hide file tree
Showing 15 changed files with 155 additions and 165 deletions.
27 changes: 25 additions & 2 deletions compiler/rustc_middle/src/mir/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub struct MirPatch<'tcx> {
resume_block: Option<BasicBlock>,
// Only for unreachable in cleanup path.
unreachable_cleanup_block: Option<BasicBlock>,
// Only for unreachable not in cleanup path.
unreachable_no_cleanup_block: Option<BasicBlock>,
// Cached block for UnwindTerminate (with reason)
terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
body_span: Span,
Expand All @@ -27,6 +29,7 @@ impl<'tcx> MirPatch<'tcx> {
next_local: body.local_decls.len(),
resume_block: None,
unreachable_cleanup_block: None,
unreachable_no_cleanup_block: None,
terminate_block: None,
body_span: body.span,
};
Expand All @@ -43,9 +46,12 @@ impl<'tcx> MirPatch<'tcx> {
// Check if we already have an unreachable block
if matches!(block.terminator().kind, TerminatorKind::Unreachable)
&& block.statements.is_empty()
&& block.is_cleanup
{
result.unreachable_cleanup_block = Some(bb);
if block.is_cleanup {
result.unreachable_cleanup_block = Some(bb);
} else {
result.unreachable_no_cleanup_block = Some(bb);
}
continue;
}

Expand Down Expand Up @@ -95,6 +101,23 @@ impl<'tcx> MirPatch<'tcx> {
bb
}

pub fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
if let Some(bb) = self.unreachable_no_cleanup_block {
return bb;
}

let bb = self.new_block(BasicBlockData {
statements: vec![],
terminator: Some(Terminator {
source_info: SourceInfo::outermost(self.body_span),
kind: TerminatorKind::Unreachable,
}),
is_cleanup: false,
});
self.unreachable_no_cleanup_block = Some(bb);
bb
}

pub fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
if let Some((cached_bb, cached_reason)) = self.terminate_block
&& reason == cached_reason
Expand Down
67 changes: 30 additions & 37 deletions compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
use crate::MirPass;
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::{
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind,
};
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{Ty, TyCtxt};
Expand Down Expand Up @@ -77,8 +78,8 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("UninhabitedEnumBranching starting for {:?}", body.source);

let mut removable_switchs = Vec::new();
let mut otherwise_is_last_variant_switchs = Vec::new();
let mut unreachable_targets = Vec::new();
let mut patch = MirPatch::new(body);

for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
trace!("processing block {:?}", bb);
Expand Down Expand Up @@ -107,49 +108,41 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {

trace!("allowed_variants = {:?}", allowed_variants);

let terminator = bb_data.terminator();
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
unreachable_targets.clear();
let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
bug!()
};

for (index, (val, _)) in targets.iter().enumerate() {
if !allowed_variants.remove(&val) {
removable_switchs.push((bb, index));
unreachable_targets.push(index);
}
}

if allowed_variants.is_empty() {
removable_switchs.push((bb, targets.iter().count()));
} else if allowed_variants.len() == 1
&& !body.basic_blocks[targets.otherwise()].is_empty_unreachable()
{
#[allow(rustc::potential_query_instability)]
let last_variant = *allowed_variants.iter().next().unwrap();
otherwise_is_last_variant_switchs.push((bb, last_variant));
}
}
let replace_otherwise_to_unreachable = allowed_variants.len() <= 1
&& !body.basic_blocks[targets.otherwise()].is_empty_unreachable();

for (bb, last_variant) in otherwise_is_last_variant_switchs {
let bb_data = &mut body.basic_blocks.as_mut()[bb];
let terminator = bb_data.terminator_mut();
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
targets.add_target(last_variant, targets.otherwise());
removable_switchs.push((bb, targets.iter().count()));
}
if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
continue;
}

if removable_switchs.is_empty() {
return;
let unreachable_block = patch.unreachable_no_cleanup_block();
let mut targets = targets.clone();
if replace_otherwise_to_unreachable {
let otherwise_is_last_variant = !allowed_variants.is_empty();
if otherwise_is_last_variant {
#[allow(rustc::potential_query_instability)]
let last_variant = *allowed_variants.iter().next().unwrap();
targets.add_target(last_variant, targets.otherwise());
}
unreachable_targets.push(targets.iter().count());
}
for index in unreachable_targets.iter() {
targets.all_targets_mut()[*index] = unreachable_block;
}
patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
}

let new_block = BasicBlockData::new(Some(Terminator {
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
kind: TerminatorKind::Unreachable,
}));
let unreachable_block = body.basic_blocks.as_mut().push(new_block);

for (bb, index) in removable_switchs {
let bb = &mut body.basic_blocks.as_mut()[bb];
let terminator = bb.terminator_mut();
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
targets.all_targets_mut()[index] = unreachable_block;
}
patch.apply(body);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,20 @@
+ _2 = const Option::<Layout>::None;
StorageLive(_10);
- _10 = discriminant(_2);
- switchInt(move _10) -> [0: bb1, 1: bb2, otherwise: bb6];
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb1];
+ _10 = const 0_isize;
+ switchInt(const 0_isize) -> [0: bb1, 1: bb2, otherwise: bb6];
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb1];
}

bb1: {
_11 = option::unwrap_failed() -> unwind unreachable;
unreachable;
}

bb2: {
_11 = option::unwrap_failed() -> unwind unreachable;
}

bb3: {
- _1 = move ((_2 as Some).0: std::alloc::Layout);
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
StorageDead(_10);
Expand All @@ -82,21 +86,21 @@
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
StorageLive(_8);
- _8 = _1;
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb3, unwind unreachable];
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind unreachable];
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb3, unwind unreachable];
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind unreachable];
}

bb3: {
bb4: {
StorageDead(_8);
StorageDead(_7);
StorageLive(_12);
StorageLive(_15);
_12 = discriminant(_6);
switchInt(move _12) -> [0: bb5, 1: bb4, otherwise: bb6];
switchInt(move _12) -> [0: bb6, 1: bb5, otherwise: bb1];
}

bb4: {
bb5: {
_15 = const "called `Result::unwrap()` on an `Err` value";
StorageLive(_16);
StorageLive(_17);
Expand All @@ -106,7 +110,7 @@
_14 = result::unwrap_failed(move _15, move _16) -> unwind unreachable;
}

bb5: {
bb6: {
_5 = move ((_6 as Ok).0: std::ptr::NonNull<[u8]>);
StorageDead(_15);
StorageDead(_12);
Expand All @@ -127,10 +131,6 @@
+ nop;
return;
}

bb6: {
unreachable;
}
}
+
+ ALLOC0 (size: 8, align: 4) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
+ _2 = const Option::<Layout>::None;
StorageLive(_10);
- _10 = discriminant(_2);
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb5];
- switchInt(move _10) -> [0: bb3, 1: bb4, otherwise: bb2];
+ _10 = const 0_isize;
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb5];
+ switchInt(const 0_isize) -> [0: bb3, 1: bb4, otherwise: bb2];
}

bb1: {
Expand All @@ -68,10 +68,14 @@
}

bb2: {
_11 = option::unwrap_failed() -> unwind continue;
unreachable;
}

bb3: {
_11 = option::unwrap_failed() -> unwind continue;
}

bb4: {
- _1 = move ((_2 as Some).0: std::alloc::Layout);
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
StorageDead(_10);
Expand All @@ -86,20 +90,16 @@
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
StorageLive(_8);
- _8 = _1;
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind continue];
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb5, unwind continue];
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind continue];
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb5, unwind continue];
}

bb4: {
bb5: {
StorageDead(_8);
StorageDead(_7);
_5 = Result::<NonNull<[u8]>, std::alloc::AllocError>::unwrap(move _6) -> [return: bb1, unwind continue];
}

bb5: {
unreachable;
}
}
+
+ ALLOC0 (size: 8, align: 4) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,20 @@
+ _2 = const Option::<Layout>::None;
StorageLive(_10);
- _10 = discriminant(_2);
- switchInt(move _10) -> [0: bb1, 1: bb2, otherwise: bb6];
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb1];
+ _10 = const 0_isize;
+ switchInt(const 0_isize) -> [0: bb1, 1: bb2, otherwise: bb6];
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb1];
}

bb1: {
_11 = option::unwrap_failed() -> unwind unreachable;
unreachable;
}

bb2: {
_11 = option::unwrap_failed() -> unwind unreachable;
}

bb3: {
- _1 = move ((_2 as Some).0: std::alloc::Layout);
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
StorageDead(_10);
Expand All @@ -82,21 +86,21 @@
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
StorageLive(_8);
- _8 = _1;
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb3, unwind unreachable];
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind unreachable];
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb3, unwind unreachable];
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind unreachable];
}

bb3: {
bb4: {
StorageDead(_8);
StorageDead(_7);
StorageLive(_12);
StorageLive(_15);
_12 = discriminant(_6);
switchInt(move _12) -> [0: bb5, 1: bb4, otherwise: bb6];
switchInt(move _12) -> [0: bb6, 1: bb5, otherwise: bb1];
}

bb4: {
bb5: {
_15 = const "called `Result::unwrap()` on an `Err` value";
StorageLive(_16);
StorageLive(_17);
Expand All @@ -106,7 +110,7 @@
_14 = result::unwrap_failed(move _15, move _16) -> unwind unreachable;
}

bb5: {
bb6: {
_5 = move ((_6 as Ok).0: std::ptr::NonNull<[u8]>);
StorageDead(_15);
StorageDead(_12);
Expand All @@ -127,10 +131,6 @@
+ nop;
return;
}

bb6: {
unreachable;
}
}
+
+ ALLOC0 (size: 16, align: 8) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
+ _2 = const Option::<Layout>::None;
StorageLive(_10);
- _10 = discriminant(_2);
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb5];
- switchInt(move _10) -> [0: bb3, 1: bb4, otherwise: bb2];
+ _10 = const 0_isize;
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb5];
+ switchInt(const 0_isize) -> [0: bb3, 1: bb4, otherwise: bb2];
}

bb1: {
Expand All @@ -68,10 +68,14 @@
}

bb2: {
_11 = option::unwrap_failed() -> unwind continue;
unreachable;
}

bb3: {
_11 = option::unwrap_failed() -> unwind continue;
}

bb4: {
- _1 = move ((_2 as Some).0: std::alloc::Layout);
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
StorageDead(_10);
Expand All @@ -86,20 +90,16 @@
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
StorageLive(_8);
- _8 = _1;
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind continue];
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb5, unwind continue];
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind continue];
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb5, unwind continue];
}

bb4: {
bb5: {
StorageDead(_8);
StorageDead(_7);
_5 = Result::<NonNull<[u8]>, std::alloc::AllocError>::unwrap(move _6) -> [return: bb1, unwind continue];
}

bb5: {
unreachable;
}
}
+
+ ALLOC0 (size: 16, align: 8) {
Expand Down
Loading

0 comments on commit b5bd98d

Please sign in to comment.