From 4c0dd893098d2e10fb46fadf00fdce8176a31ad4 Mon Sep 17 00:00:00 2001 From: Camille GILLOT Date: Sat, 20 May 2023 08:54:16 +0000 Subject: [PATCH] Simplify code flow. --- compiler/rustc_mir_transform/src/lib.rs | 1 + .../rustc_mir_transform/src/promote_consts.rs | 378 ++++++++---------- 2 files changed, 157 insertions(+), 222 deletions(-) diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index f099fdde78a4..4728348295fc 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -1,6 +1,7 @@ #![allow(rustc::potential_query_instability)] #![deny(rustc::untranslatable_diagnostic)] #![deny(rustc::diagnostic_outside_of_impl)] +#![feature(assert_matches)] #![feature(box_patterns)] #![feature(cow_is_borrowed)] #![feature(decl_macro)] diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs index 02ea5ddb99d0..b7564f0b9ab7 100644 --- a/compiler/rustc_mir_transform/src/promote_consts.rs +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -12,6 +12,7 @@ //! initialization and can otherwise silence errors, if //! move analysis runs after promotion on broken MIR. +use either::{Left, Right}; use rustc_hir as hir; use rustc_middle::mir; use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; @@ -22,6 +23,7 @@ use rustc_span::Span; use rustc_index::{Idx, IndexSlice, IndexVec}; +use std::assert_matches::assert_matches; use std::cell::Cell; use std::{cmp, iter, mem}; @@ -116,41 +118,38 @@ impl<'tcx> Visitor<'tcx> for Collector<'_, 'tcx> { let temp = &mut self.temps[index]; debug!("visit_local: temp={:?}", temp); - if *temp == TempState::Undefined { - match context { + *temp = match *temp { + TempState::Undefined => match context { PlaceContext::MutatingUse(MutatingUseContext::Store) | PlaceContext::MutatingUse(MutatingUseContext::Call) => { - *temp = TempState::Defined { location, uses: 0, valid: Err(()) }; + TempState::Defined { location, uses: 0, valid: Err(()) } + } + _ => TempState::Unpromotable, + }, + TempState::Defined { ref mut uses, .. } => { + // We always allow borrows, even mutable ones, as we need + // to promote mutable borrows of some ZSTs e.g., `&mut []`. + let allowed_use = match context { + PlaceContext::MutatingUse(MutatingUseContext::Borrow) + | PlaceContext::NonMutatingUse(_) => true, + PlaceContext::MutatingUse(_) | PlaceContext::NonUse(_) => false, + }; + debug!("visit_local: allowed_use={:?}", allowed_use); + if allowed_use { + *uses += 1; return; } - _ => { /* mark as unpromotable below */ } + TempState::Unpromotable } - } else if let TempState::Defined { uses, .. } = temp { - // We always allow borrows, even mutable ones, as we need - // to promote mutable borrows of some ZSTs e.g., `&mut []`. - let allowed_use = match context { - PlaceContext::MutatingUse(MutatingUseContext::Borrow) - | PlaceContext::NonMutatingUse(_) => true, - PlaceContext::MutatingUse(_) | PlaceContext::NonUse(_) => false, - }; - debug!("visit_local: allowed_use={:?}", allowed_use); - if allowed_use { - *uses += 1; - return; - } - /* mark as unpromotable below */ - } - *temp = TempState::Unpromotable; + _ => TempState::Unpromotable, + }; } fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { self.super_rvalue(rvalue, location); - match *rvalue { - Rvalue::Ref(..) => { - self.candidates.push(Candidate { location }); - } - _ => {} + if let Rvalue::Ref(..) = *rvalue { + self.candidates.push(Candidate { location }); } } } @@ -189,230 +188,165 @@ struct Unpromotable; impl<'tcx> Validator<'_, 'tcx> { fn validate_candidate(&mut self, candidate: Candidate) -> Result<(), Unpromotable> { - let loc = candidate.location; - let statement = &self.body[loc.block].statements[loc.statement_index]; - match &statement.kind { - StatementKind::Assign(box (_, Rvalue::Ref(_, kind, place))) => { - // We can only promote interior borrows of promotable temps (non-temps - // don't get promoted anyway). - self.validate_local(place.local)?; - - // The reference operation itself must be promotable. - // (Needs to come after `validate_local` to avoid ICEs.) - self.validate_ref(*kind, place)?; + let Left(statement) = self.body.stmt_at(candidate.location) else { bug!() }; + let Some((_, Rvalue::Ref(_, kind, place))) = statement.kind.as_assign() else { bug!() }; - // We do not check all the projections (they do not get promoted anyway), - // but we do stay away from promoting anything involving a dereference. - if place.projection.contains(&ProjectionElem::Deref) { - return Err(Unpromotable); - } + // We can only promote interior borrows of promotable temps (non-temps + // don't get promoted anyway). + self.validate_local(place.local)?; - Ok(()) - } - _ => bug!(), + // The reference operation itself must be promotable. + // (Needs to come after `validate_local` to avoid ICEs.) + self.validate_ref(*kind, place)?; + + // We do not check all the projections (they do not get promoted anyway), + // but we do stay away from promoting anything involving a dereference. + if place.projection.contains(&ProjectionElem::Deref) { + return Err(Unpromotable); } + + Ok(()) } // FIXME(eddyb) maybe cache this? fn qualif_local(&mut self, local: Local) -> bool { - if let TempState::Defined { location: loc, .. } = self.temps[local] { - let num_stmts = self.body[loc.block].statements.len(); - - if loc.statement_index < num_stmts { - let statement = &self.body[loc.block].statements[loc.statement_index]; - match &statement.kind { - StatementKind::Assign(box (_, rhs)) => qualifs::in_rvalue::( - self.ccx, - &mut |l| self.qualif_local::(l), - rhs, - ), - _ => { + let TempState::Defined { location: loc, .. } = self.temps[local] else { + return false; + }; + + let stmt_or_term = self.body.stmt_at(loc); + match stmt_or_term { + Left(statement) => { + let Some((_, rhs)) = statement.kind.as_assign() else { + span_bug!(statement.source_info.span, "{:?} is not an assignment", statement) + }; + qualifs::in_rvalue::(self.ccx, &mut |l| self.qualif_local::(l), rhs) + } + Right(terminator) => { + assert_matches!(terminator.kind, TerminatorKind::Call { .. }); + let return_ty = self.body.local_decls[local].ty; + Q::in_any_value_of_ty(self.ccx, return_ty) + } + } + } + + fn validate_local(&mut self, local: Local) -> Result<(), Unpromotable> { + let TempState::Defined { location: loc, uses, valid } = self.temps[local] else { + return Err(Unpromotable); + }; + + // We cannot promote things that need dropping, since the promoted value would not get + // dropped. + if self.qualif_local::(local) { + return Err(Unpromotable); + } + + if valid.is_ok() { + return Ok(()); + } + + let ok = { + let stmt_or_term = self.body.stmt_at(loc); + match stmt_or_term { + Left(statement) => { + let Some((_, rhs)) = statement.kind.as_assign() else { span_bug!( statement.source_info.span, "{:?} is not an assignment", statement - ); - } + ) + }; + self.validate_rvalue(rhs) } - } else { - let terminator = self.body[loc.block].terminator(); - match &terminator.kind { - TerminatorKind::Call { .. } => { - let return_ty = self.body.local_decls[local].ty; - Q::in_any_value_of_ty(self.ccx, return_ty) - } + Right(terminator) => match &terminator.kind { + TerminatorKind::Call { func, args, .. } => self.validate_call(func, args), + TerminatorKind::Yield { .. } => Err(Unpromotable), kind => { span_bug!(terminator.source_info.span, "{:?} not promotable", kind); } - } + }, } - } else { - false - } - } + }; - fn validate_local(&mut self, local: Local) -> Result<(), Unpromotable> { - if let TempState::Defined { location: loc, uses, valid } = self.temps[local] { - // We cannot promote things that need dropping, since the promoted value - // would not get dropped. - if self.qualif_local::(local) { - return Err(Unpromotable); - } - valid.or_else(|_| { - let ok = { - let block = &self.body[loc.block]; - let num_stmts = block.statements.len(); - - if loc.statement_index < num_stmts { - let statement = &block.statements[loc.statement_index]; - match &statement.kind { - StatementKind::Assign(box (_, rhs)) => self.validate_rvalue(rhs), - _ => { - span_bug!( - statement.source_info.span, - "{:?} is not an assignment", - statement - ); - } - } - } else { - let terminator = block.terminator(); - match &terminator.kind { - TerminatorKind::Call { func, args, .. } => { - self.validate_call(func, args) - } - TerminatorKind::Yield { .. } => Err(Unpromotable), - kind => { - span_bug!(terminator.source_info.span, "{:?} not promotable", kind); - } - } - } - }; - self.temps[local] = match ok { - Ok(()) => TempState::Defined { location: loc, uses, valid: Ok(()) }, - Err(_) => TempState::Unpromotable, - }; - ok - }) - } else { - Err(Unpromotable) - } + self.temps[local] = match ok { + Ok(()) => TempState::Defined { location: loc, uses, valid: Ok(()) }, + Err(_) => TempState::Unpromotable, + }; + + ok } fn validate_place(&mut self, place: PlaceRef<'tcx>) -> Result<(), Unpromotable> { - match place.last_projection() { - None => self.validate_local(place.local), - Some((place_base, elem)) => { - // Validate topmost projection, then recurse. - match elem { - ProjectionElem::Deref => { - let mut promotable = false; - // When a static is used by-value, that gets desugared to `*STATIC_ADDR`, - // and we need to be able to promote this. So check if this deref matches - // that specific pattern. - - // We need to make sure this is a `Deref` of a local with no further projections. - // Discussion can be found at - // https://github.com/rust-lang/rust/pull/74945#discussion_r463063247 - if let Some(local) = place_base.as_local() { - if let TempState::Defined { location, .. } = self.temps[local] { - let def_stmt = self.body[location.block] - .statements - .get(location.statement_index); - if let Some(Statement { - kind: - StatementKind::Assign(box ( - _, - Rvalue::Use(Operand::Constant(c)), - )), - .. - }) = def_stmt - { - if let Some(did) = c.check_static_ptr(self.tcx) { - // Evaluating a promoted may not read statics except if it got - // promoted from a static (this is a CTFE check). So we - // can only promote static accesses inside statics. - if let Some(hir::ConstContext::Static(..)) = self.const_kind - { - if !self.tcx.is_thread_local_static(did) { - promotable = true; - } - } - } - } - } - } - if !promotable { - return Err(Unpromotable); - } - } - ProjectionElem::OpaqueCast(..) | ProjectionElem::Downcast(..) => { - return Err(Unpromotable); - } + let Some((place_base, elem)) = place.last_projection() else { + return self.validate_local(place.local); + }; - ProjectionElem::ConstantIndex { .. } - | ProjectionElem::Subtype(_) - | ProjectionElem::Subslice { .. } => {} - - ProjectionElem::Index(local) => { - let mut promotable = false; - // Only accept if we can predict the index and are indexing an array. - let val = if let TempState::Defined { location: loc, .. } = - self.temps[local] - { - let block = &self.body[loc.block]; - if loc.statement_index < block.statements.len() { - let statement = &block.statements[loc.statement_index]; - match &statement.kind { - StatementKind::Assign(box ( - _, - Rvalue::Use(Operand::Constant(c)), - )) => c.const_.try_eval_target_usize(self.tcx, self.param_env), - _ => None, - } - } else { - None - } - } else { - None - }; - if let Some(idx) = val { - // Determine the type of the thing we are indexing. - let ty = place_base.ty(self.body, self.tcx).ty; - match ty.kind() { - ty::Array(_, len) => { - // It's an array; determine its length. - if let Some(len) = - len.try_eval_target_usize(self.tcx, self.param_env) - { - // If the index is in-bounds, go ahead. - if idx < len { - promotable = true; - } - } - } - _ => {} - } - } - if !promotable { - return Err(Unpromotable); - } + // Validate topmost projection, then recurse. + match elem { + // Recurse directly. + ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subtype(_) + | ProjectionElem::Subslice { .. } => {} - self.validate_local(local)?; - } + // Never recurse. + ProjectionElem::OpaqueCast(..) | ProjectionElem::Downcast(..) => { + return Err(Unpromotable); + } - ProjectionElem::Field(..) => { - let base_ty = place_base.ty(self.body, self.tcx).ty; - if base_ty.is_union() { - // No promotion of union field accesses. - return Err(Unpromotable); - } - } + ProjectionElem::Deref => { + // When a static is used by-value, that gets desugared to `*STATIC_ADDR`, + // and we need to be able to promote this. So check if this deref matches + // that specific pattern. + + // We need to make sure this is a `Deref` of a local with no further projections. + // Discussion can be found at + // https://github.com/rust-lang/rust/pull/74945#discussion_r463063247 + if let Some(local) = place_base.as_local() + && let TempState::Defined { location, .. } = self.temps[local] + && let Left(def_stmt) = self.body.stmt_at(location) + && let Some((_, Rvalue::Use(Operand::Constant(c)))) = def_stmt.kind.as_assign() + && let Some(did) = c.check_static_ptr(self.tcx) + // Evaluating a promoted may not read statics except if it got + // promoted from a static (this is a CTFE check). So we + // can only promote static accesses inside statics. + && let Some(hir::ConstContext::Static(..)) = self.const_kind + && !self.tcx.is_thread_local_static(did) + { + // Recurse. + } else { + return Err(Unpromotable); } + } + ProjectionElem::Index(local) => { + // Only accept if we can predict the index and are indexing an array. + if let TempState::Defined { location: loc, .. } = self.temps[local] + && let Left(statement) = self.body.stmt_at(loc) + && let Some((_, Rvalue::Use(Operand::Constant(c)))) = statement.kind.as_assign() + && let Some(idx) = c.const_.try_eval_target_usize(self.tcx, self.param_env) + // Determine the type of the thing we are indexing. + && let ty::Array(_, len) = place_base.ty(self.body, self.tcx).ty.kind() + // It's an array; determine its length. + && let Some(len) = len.try_eval_target_usize(self.tcx, self.param_env) + // If the index is in-bounds, go ahead. + && idx < len + { + self.validate_local(local)?; + // Recurse. + } else { + return Err(Unpromotable); + } + } - self.validate_place(place_base) + ProjectionElem::Field(..) => { + let base_ty = place_base.ty(self.body, self.tcx).ty; + if base_ty.is_union() { + // No promotion of union field accesses. + return Err(Unpromotable); + } } } + + self.validate_place(place_base) } fn validate_operand(&mut self, operand: &Operand<'tcx>) -> Result<(), Unpromotable> {