diff --git a/kani-compiler/src/args.rs b/kani-compiler/src/args.rs
index 69a82b61e19d..fcc346bd745f 100644
--- a/kani-compiler/src/args.rs
+++ b/kani-compiler/src/args.rs
@@ -71,4 +71,15 @@ pub struct Arguments {
 #[clap(long)]
 /// A legacy flag that is now ignored.
 goto_c: bool,
+ /// Enable specific checks.
+ #[clap(long)]
+ pub ub_check: Vec<ExtraChecks>,
+}
+
+#[derive(Debug, Clone, Copy, AsRefStr, EnumString, VariantNames, PartialEq, Eq)]
+#[strum(serialize_all = "snake_case")]
+pub enum ExtraChecks {
+ /// Check that produced values are valid except for uninitialized values.
+ /// See https://github.com/model-checking/kani/issues/920.
+ Validity,
}
diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs
index 7cf4b979f407..d55696bfdc87 100644
--- a/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs
+++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs
@@ -60,7 +60,7 @@ impl<'tcx> GotocCtx<'tcx> {
 debug!("Double codegen of {:?}", old_sym);
 } else {
 assert!(old_sym.is_function());
- let body = instance.body().unwrap();
+ let body = self.transformer.body(self.tcx, instance);
 self.set_current_fn(instance, &body);
 self.print_instance(instance, &body);
 self.codegen_function_prelude(&body);
@@ -201,7 +201,7 @@ impl<'tcx> GotocCtx<'tcx> {
 pub fn declare_function(&mut self, instance: Instance) {
 debug!("declaring {}; {:?}", instance.name(), instance);
- let body = instance.body().unwrap();
+ let body = self.transformer.body(self.tcx, instance);
 self.set_current_fn(instance, &body);
 debug!(krate=?instance.def.krate(), is_std=self.current_fn().is_std(), "declare_function");
 self.ensure(&self.symbol_name_stable(instance), |ctx, fname| {
diff --git a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs
index cad8bcbfb9ce..d3b4ae8473de 100644
--- a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs
+++ b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs
@@ -12,6 +12,7 @@ use crate::kani_middle::provide;
 use crate::kani_middle::reachability::{
 collect_reachable_items, filter_const_crate_items, filter_crate_items,
 };
+use crate::kani_middle::transform::BodyTransformation;
 use crate::kani_middle::{check_reachable_items, dump_mir_items};
 use crate::kani_queries::QueryDb;
 use cbmc::goto_program::Location;
@@ -86,16 +87,18 @@ impl GotocCodegenBackend {
 symtab_goto: &Path,
 machine_model: &MachineModel,
 check_contract: Option<InternalDefId>,
+ mut transformer: BodyTransformation,
 ) -> (GotocCtx<'tcx>, Vec<MonoItem>, Option<AssignsContract>) {
 let items = with_timer(
- || collect_reachable_items(tcx, starting_items),
+ || collect_reachable_items(tcx, &mut transformer, starting_items),
 "codegen reachability analysis",
 );
 dump_mir_items(tcx, &items, &symtab_goto.with_extension("kani.mir"));
 // Follow rustc naming convention (cx is abbrev for context).
 // https://rustc-dev-guide.rust-lang.org/conventions.html#naming-conventions
- let mut gcx = GotocCtx::new(tcx, (*self.queries.lock().unwrap()).clone(), machine_model);
+ let mut gcx =
+ GotocCtx::new(tcx, (*self.queries.lock().unwrap()).clone(), machine_model, transformer);
 check_reachable_items(gcx.tcx, &gcx.queries, &items);
 let contract_info = with_timer(
@@ -227,6 +230,7 @@ impl CodegenBackend for GotocCodegenBackend {
 // - None: Don't generate code. This is used to compile dependencies.
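[Illustrative aside on the `args.rs` hunk above, not part of the patch: deriving `EnumString` with `#[strum(serialize_all = "snake_case")]` is what lets clap parse `--ub-check=validity` into the enum. Assuming strum's generated impls, the round-trip looks like:

    use std::str::FromStr;
    // `EnumString` provides `FromStr`; `AsRefStr` provides `as_ref()`.
    assert_eq!(ExtraChecks::from_str("validity").unwrap(), ExtraChecks::Validity);
    assert_eq!(ExtraChecks::Validity.as_ref(), "validity");
]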
let base_filename = tcx.output_filenames(()).output_path(OutputType::Object); let reachability = queries.args().reachability_analysis; + let mut transformer = BodyTransformation::new(&queries, tcx); let mut results = GotoCodegenResults::new(tcx, reachability); match reachability { ReachabilityType::Harnesses => { @@ -248,8 +252,9 @@ impl CodegenBackend for GotocCodegenBackend { model_path, &results.machine_model, contract_metadata, + transformer, ); - results.extend(gcx, items, None); + transformer = results.extend(gcx, items, None); if let Some(assigns_contract) = contract_info { self.queries.lock().unwrap().register_assigns_contract( canonical_mangled_name(harness).intern(), @@ -263,7 +268,7 @@ impl CodegenBackend for GotocCodegenBackend { // test closure that we want to execute // TODO: Refactor this code so we can guarantee that the pair (test_fn, test_desc) actually match. let mut descriptions = vec![]; - let harnesses = filter_const_crate_items(tcx, |_, item| { + let harnesses = filter_const_crate_items(tcx, &mut transformer, |_, item| { if is_test_harness_description(tcx, item.def) { descriptions.push(item.def); true @@ -282,6 +287,7 @@ impl CodegenBackend for GotocCodegenBackend { &model_path, &results.machine_model, Default::default(), + transformer, ); results.extend(gcx, items, None); @@ -319,9 +325,10 @@ impl CodegenBackend for GotocCodegenBackend { &model_path, &results.machine_model, Default::default(), + transformer, ); assert!(contract_info.is_none()); - results.extend(gcx, items, None); + let _ = results.extend(gcx, items, None); } } @@ -613,12 +620,18 @@ impl GotoCodegenResults { } } - fn extend(&mut self, gcx: GotocCtx, items: Vec, metadata: Option) { + fn extend( + &mut self, + gcx: GotocCtx, + items: Vec, + metadata: Option, + ) -> BodyTransformation { let mut items = items; self.harnesses.extend(metadata); self.concurrent_constructs.extend(gcx.concurrent_constructs); self.unsupported_constructs.extend(gcx.unsupported_constructs); self.items.append(&mut items); + gcx.transformer } /// Prints a report at the end of the compilation. diff --git a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs index 10ee6a876588..a25e502aa574 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs @@ -18,6 +18,7 @@ use super::vtable_ctx::VtableCtx; use crate::codegen_cprover_gotoc::overrides::{fn_hooks, GotocHooks}; use crate::codegen_cprover_gotoc::utils::full_crate_name; use crate::codegen_cprover_gotoc::UnsupportedConstructs; +use crate::kani_middle::transform::BodyTransformation; use crate::kani_queries::QueryDb; use cbmc::goto_program::{DatatypeComponent, Expr, Location, Stmt, Symbol, SymbolTable, Type}; use cbmc::utils::aggr_tag; @@ -70,6 +71,8 @@ pub struct GotocCtx<'tcx> { /// We collect them and print one warning at the end if not empty instead of printing one /// warning at each occurrence. pub concurrent_constructs: UnsupportedConstructs, + /// The body transformation agent. 
+ pub transformer: BodyTransformation, } /// Constructor @@ -78,6 +81,7 @@ impl<'tcx> GotocCtx<'tcx> { tcx: TyCtxt<'tcx>, queries: QueryDb, machine_model: &MachineModel, + transformer: BodyTransformation, ) -> GotocCtx<'tcx> { let fhks = fn_hooks(); let symbol_table = SymbolTable::new(machine_model.clone()); @@ -99,6 +103,7 @@ impl<'tcx> GotocCtx<'tcx> { global_checks_count: 0, unsupported_constructs: FxHashMap::default(), concurrent_constructs: FxHashMap::default(), + transformer, } } } diff --git a/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs b/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs index 2fcbaa36fb44..18cc44b7b20d 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs @@ -10,6 +10,7 @@ use crate::codegen_cprover_gotoc::codegen::{bb_label, PropertyClass}; use crate::codegen_cprover_gotoc::GotocCtx; +use crate::kani_middle::attributes::matches_diagnostic as matches_function; use crate::unwrap_or_return_codegen_unimplemented_stmt; use cbmc::goto_program::{BuiltinFn, Expr, Location, Stmt, Type}; use rustc_middle::ty::TyCtxt; @@ -35,17 +36,6 @@ pub trait GotocHook { ) -> Stmt; } -fn matches_function(tcx: TyCtxt, instance: Instance, attr_name: &str) -> bool { - let attr_sym = rustc_span::symbol::Symbol::intern(attr_name); - if let Some(attr_id) = tcx.all_diagnostic_items(()).name_to_id.get(&attr_sym) { - if rustc_internal::internal(tcx, instance.def.def_id()) == *attr_id { - debug!("matched: {:?} {:?}", attr_id, attr_sym); - return true; - } - } - false -} - /// A hook for Kani's `cover` function (declared in `library/kani/src/lib.rs`). /// The function takes two arguments: a condition expression (bool) and a /// message (&'static str). @@ -57,7 +47,7 @@ fn matches_function(tcx: TyCtxt, instance: Instance, attr_name: &str) -> bool { struct Cover; impl GotocHook for Cover { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniCover") + matches_function(tcx, instance.def, "KaniCover") } fn handle( @@ -92,7 +82,7 @@ impl GotocHook for Cover { struct Assume; impl GotocHook for Assume { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniAssume") + matches_function(tcx, instance.def, "KaniAssume") } fn handle( @@ -116,7 +106,7 @@ impl GotocHook for Assume { struct Assert; impl GotocHook for Assert { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniAssert") + matches_function(tcx, instance.def, "KaniAssert") } fn handle( @@ -157,7 +147,7 @@ struct Nondet; impl GotocHook for Nondet { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniAnyRaw") + matches_function(tcx, instance.def, "KaniAnyRaw") } fn handle( @@ -201,7 +191,7 @@ impl GotocHook for Panic { || tcx.has_attr(def_id, rustc_span::sym::rustc_const_panic_str) || Some(def_id) == tcx.lang_items().panic_fmt() || Some(def_id) == tcx.lang_items().begin_panic_fn() - || matches_function(tcx, instance, "KaniPanic") + || matches_function(tcx, instance.def, "KaniPanic") } fn handle( @@ -221,7 +211,7 @@ impl GotocHook for Panic { struct IsReadOk; impl GotocHook for IsReadOk { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniIsReadOk") + matches_function(tcx, instance.def, "KaniIsReadOk") } fn handle( @@ -365,7 +355,7 @@ struct UntrackedDeref; impl GotocHook for 
UntrackedDeref {
 fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool {
- matches_function(tcx, instance, "KaniUntrackedDeref")
+ matches_function(tcx, instance.def, "KaniUntrackedDeref")
 }
 fn handle(
diff --git a/kani-compiler/src/kani_middle/attributes.rs b/kani-compiler/src/kani_middle/attributes.rs
index df494d7daa3c..979877199dde 100644
--- a/kani-compiler/src/kani_middle/attributes.rs
+++ b/kani-compiler/src/kani_middle/attributes.rs
@@ -1037,3 +1037,14 @@ fn attr_kind(tcx: TyCtxt, attr: &Attribute) -> Option<KaniAttributeKind> {
 _ => None,
 }
 }
+
+pub fn matches_diagnostic<T: CrateDef>(tcx: TyCtxt, def: T, attr_name: &str) -> bool {
+ let attr_sym = rustc_span::symbol::Symbol::intern(attr_name);
+ if let Some(attr_id) = tcx.all_diagnostic_items(()).name_to_id.get(&attr_sym) {
+ if rustc_internal::internal(tcx, def.def_id()) == *attr_id {
+ debug!("matched: {:?} {:?}", attr_id, attr_sym);
+ return true;
+ }
+ }
+ false
+}
diff --git a/kani-compiler/src/kani_middle/mod.rs b/kani-compiler/src/kani_middle/mod.rs
index e65befc9624e..4cd2d4a357d7 100644
--- a/kani-compiler/src/kani_middle/mod.rs
+++ b/kani-compiler/src/kani_middle/mod.rs
@@ -23,7 +23,7 @@ use rustc_target::abi::call::FnAbi;
 use rustc_target::abi::{HasDataLayout, TargetDataLayout};
 use stable_mir::mir::mono::{Instance, InstanceKind, MonoItem};
 use stable_mir::mir::pretty::pretty_ty;
-use stable_mir::ty::{BoundVariableKind, RigidTy, Span as SpanStable, Ty, TyKind};
+use stable_mir::ty::{BoundVariableKind, FnDef, RigidTy, Span as SpanStable, Ty, TyKind};
 use stable_mir::visitor::{Visitable, Visitor as TypeVisitor};
 use stable_mir::{CrateDef, DefId};
 use std::fs::File;
@@ -41,6 +41,7 @@ pub mod provide;
 pub mod reachability;
 pub mod resolve;
 pub mod stubbing;
+pub mod transform;
 /// Check that all crate items are supported and there's no misconfiguration.
 /// This method will exhaustively print any error / warning and it will abort at the end if any
@@ -316,3 +317,18 @@ impl<'tcx> FnAbiOfHelpers<'tcx> for CompilerHelpers<'tcx> {
 }
 }
 }
+
+/// Find an instance of a function from the given crate that has been annotated with the given
+/// `diagnostic` item.
+fn find_fn_def(tcx: TyCtxt, diagnostic: &str) -> Option<FnDef> {
+ let attr_id = tcx
+ .all_diagnostic_items(())
+ .name_to_id
+ .get(&rustc_span::symbol::Symbol::intern(diagnostic))?;
+ let TyKind::RigidTy(RigidTy::FnDef(def, _)) =
+ rustc_internal::stable(tcx.type_of(attr_id)).value.kind()
+ else {
+ return None;
+ };
+ Some(def)
+}
diff --git a/kani-compiler/src/kani_middle/provide.rs b/kani-compiler/src/kani_middle/provide.rs
index d5495acb67a7..d29635ecfd15 100644
--- a/kani-compiler/src/kani_middle/provide.rs
+++ b/kani-compiler/src/kani_middle/provide.rs
@@ -8,6 +8,7 @@ use crate::args::{Arguments, ReachabilityType};
 use crate::kani_middle::intrinsics::ModelIntrinsics;
 use crate::kani_middle::reachability::{collect_reachable_items, filter_crate_items};
 use crate::kani_middle::stubbing;
+use crate::kani_middle::transform::BodyTransformation;
 use crate::kani_queries::QueryDb;
 use rustc_hir::def_id::{DefId, LocalDefId};
 use rustc_middle::util::Providers;
@@ -79,8 +80,9 @@ fn collect_and_partition_mono_items(
 rustc_smir::rustc_internal::run(tcx, || {
 let local_reachable =
 filter_crate_items(tcx, |_, _| true).into_iter().map(MonoItem::Fn).collect::<Vec<_>>();
+ // We do not actually need the value returned here.
- collect_reachable_items(tcx, &local_reachable);
+ collect_reachable_items(tcx, &mut BodyTransformation::dummy(), &local_reachable);
 })
 .unwrap();
 (rustc_interface::DEFAULT_QUERY_PROVIDERS.collect_and_partition_mono_items)(tcx, key)
diff --git a/kani-compiler/src/kani_middle/reachability.rs b/kani-compiler/src/kani_middle/reachability.rs
index 2fa1bf057c1c..b46fbdccc016 100644
--- a/kani-compiler/src/kani_middle/reachability.rs
+++ b/kani-compiler/src/kani_middle/reachability.rs
@@ -37,12 +37,17 @@ use stable_mir::{CrateDef, ItemKind};
 use crate::kani_middle::coercion;
 use crate::kani_middle::coercion::CoercionBase;
 use crate::kani_middle::stubbing::{get_stub, validate_instance};
+use crate::kani_middle::transform::BodyTransformation;
 /// Collect all reachable items starting from the given starting points.
-pub fn collect_reachable_items(tcx: TyCtxt, starting_points: &[MonoItem]) -> Vec<MonoItem> {
+pub fn collect_reachable_items(
+ tcx: TyCtxt,
+ transformer: &mut BodyTransformation,
+ starting_points: &[MonoItem],
+) -> Vec<MonoItem> {
 // For each harness, collect items using the same collector.
 // I.e.: This will return any item that is reachable from one or more of the starting points.
- let mut collector = MonoItemsCollector::new(tcx);
+ let mut collector = MonoItemsCollector::new(tcx, transformer);
 for item in starting_points {
 collector.collect(item.clone());
 }
@@ -92,7 +97,11 @@ where
 ///
 /// Probably only specifically useful with a predicate to find `TestDescAndFn` const declarations from
 /// tests and extract the closures from them.
-pub fn filter_const_crate_items<F>(tcx: TyCtxt, mut predicate: F) -> Vec<MonoItem>
+pub fn filter_const_crate_items<F>(
+ tcx: TyCtxt,
+ transformer: &mut BodyTransformation,
+ mut predicate: F,
+) -> Vec<MonoItem>
 where
 F: FnMut(TyCtxt, Instance) -> bool,
 {
@@ -103,7 +112,7 @@ where
 // Only collect monomorphic items.
 if let Ok(instance) = Instance::try_from(item) {
 if predicate(tcx, instance) {
- let body = instance.body().unwrap();
+ let body = transformer.body(tcx, instance);
 let mut collector = MonoItemsFnCollector {
 tcx,
 body: &body,
@@ -118,9 +127,11 @@ where
 roots
 }
-struct MonoItemsCollector<'tcx> {
+struct MonoItemsCollector<'tcx, 'a> {
 /// The compiler context.
 tcx: TyCtxt<'tcx>,
+ /// The body transformation object used to retrieve a transformed body.
+ transformer: &'a mut BodyTransformation,
 /// Set of collected items used to avoid entering recursion loops.
 collected: FxHashSet<MonoItem>,
 /// Items enqueued for visiting.
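[Illustrative recap, not part of the patch: the new calling convention introduced by the hunks above, with `queries`, `tcx`, and `starting_items` standing in for the caller's state:

    let mut transformer = BodyTransformation::new(&queries, tcx);
    let items = collect_reachable_items(tcx, &mut transformer, &starting_items);

Every body the collector visits now flows through `transformer.body(tcx, instance)`, so reachability operates on the instrumented MIR instead of the raw `instance.body()`.]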
@@ -129,14 +140,15 @@ struct MonoItemsCollector<'tcx> {
 call_graph: debug::CallGraph,
 }
-impl<'tcx> MonoItemsCollector<'tcx> {
- pub fn new(tcx: TyCtxt<'tcx>) -> Self {
+impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> {
+ pub fn new(tcx: TyCtxt<'tcx>, transformer: &'a mut BodyTransformation) -> Self {
 MonoItemsCollector {
 tcx,
 collected: FxHashSet::default(),
 queue: vec![],
 #[cfg(debug_assertions)]
 call_graph: debug::CallGraph::default(),
+ transformer,
 }
 }
@@ -174,7 +186,7 @@ impl<'tcx> MonoItemsCollector<'tcx> {
 fn visit_fn(&mut self, instance: Instance) -> Vec<MonoItem> {
 let _guard = debug_span!("visit_fn", function=?instance).entered();
 if validate_instance(self.tcx, instance) {
- let body = instance.body().unwrap();
+ let body = self.transformer.body(self.tcx, instance);
 let mut collector = MonoItemsFnCollector {
 tcx: self.tcx,
 collected: FxHashSet::default(),
diff --git a/kani-compiler/src/kani_middle/transform/body.rs b/kani-compiler/src/kani_middle/transform/body.rs
new file mode 100644
index 000000000000..32ffb373d767
--- /dev/null
+++ b/kani-compiler/src/kani_middle/transform/body.rs
@@ -0,0 +1,275 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+//
+//! Utility functions that allow us to modify a function body.
+
+use crate::kani_middle::find_fn_def;
+use rustc_middle::ty::TyCtxt;
+use stable_mir::mir::mono::Instance;
+use stable_mir::mir::{
+ BasicBlock, BasicBlockIdx, BinOp, Body, CastKind, Constant, Local, LocalDecl, Mutability,
+ Operand, Place, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, UnwindAction,
+ VarDebugInfo,
+};
+use stable_mir::ty::{Const, GenericArgs, Span, Ty, UintTy};
+use std::mem;
+
+/// This structure mimics a Body that can actually be modified.
+pub struct MutableBody {
+ blocks: Vec<BasicBlock>,
+
+ /// Declarations of locals within the function.
+ ///
+ /// The first local is the return value pointer, followed by `arg_count`
+ /// locals for the function arguments, followed by any user-declared
+ /// variables and temporaries.
+ locals: Vec<LocalDecl>,
+
+ /// The number of arguments this function takes.
+ arg_count: usize,
+
+ /// Debug information pertaining to user variables, including captures.
+ var_debug_info: Vec<VarDebugInfo>,
+
+ /// Mark an argument (which must be a tuple) as getting passed as its individual components.
+ ///
+ /// This is used for the "rust-call" ABI such as closures.
+ spread_arg: Option<Local>,
+
+ /// The span that covers the entire function body.
+ span: Span,
+}
+
+impl MutableBody {
+ /// Get the basic blocks of this builder.
+ pub fn blocks(&self) -> &[BasicBlock] {
+ &self.blocks
+ }
+
+ pub fn locals(&self) -> &[LocalDecl] {
+ &self.locals
+ }
+
+ /// Create a mutable body from the original MIR body.
+ pub fn from(body: Body) -> Self {
+ MutableBody {
+ locals: body.locals().to_vec(),
+ arg_count: body.arg_locals().len(),
+ spread_arg: body.spread_arg(),
+ blocks: body.blocks,
+ var_debug_info: body.var_debug_info,
+ span: body.span,
+ }
+ }
+
+ /// Create the new body consuming this mutable body.
+ pub fn into(self) -> Body {
+ Body::new(
+ self.blocks,
+ self.locals,
+ self.arg_count,
+ self.var_debug_info,
+ self.spread_arg,
+ self.span,
+ )
+ }
+
+ /// Add a new local to the body with the given attributes.
+ pub fn new_local(&mut self, ty: Ty, span: Span, mutability: Mutability) -> Local { + let decl = LocalDecl { ty, span, mutability }; + let local = self.locals.len(); + self.locals.push(decl); + local + } + + pub fn new_str_operand(&mut self, msg: &str, span: Span) -> Operand { + let literal = Const::from_str(msg); + Operand::Constant(Constant { span, user_ty: None, literal }) + } + + pub fn new_const_operand(&mut self, val: u128, uint_ty: UintTy, span: Span) -> Operand { + let literal = Const::try_from_uint(val, uint_ty).unwrap(); + Operand::Constant(Constant { span, user_ty: None, literal }) + } + + /// Create a raw pointer of `*mut type` and return a new local where that value is stored. + pub fn new_cast_ptr( + &mut self, + from: Operand, + pointee_ty: Ty, + mutability: Mutability, + before: &mut SourceInstruction, + ) -> Local { + assert!(from.ty(self.locals()).unwrap().kind().is_raw_ptr()); + let target_ty = Ty::new_ptr(pointee_ty, mutability); + let rvalue = Rvalue::Cast(CastKind::PtrToPtr, from, target_ty); + self.new_assignment(rvalue, before) + } + + /// Add a new assignment for the given binary operation. + /// + /// Return the local where the result is saved. + pub fn new_binary_op( + &mut self, + bin_op: BinOp, + lhs: Operand, + rhs: Operand, + before: &mut SourceInstruction, + ) -> Local { + let rvalue = Rvalue::BinaryOp(bin_op, lhs, rhs); + self.new_assignment(rvalue, before) + } + + /// Add a new assignment. + /// + /// Return local where the result is saved. + pub fn new_assignment(&mut self, rvalue: Rvalue, before: &mut SourceInstruction) -> Local { + let span = before.span(&self.blocks); + let ret_ty = rvalue.ty(&self.locals).unwrap(); + let result = self.new_local(ret_ty, span, Mutability::Not); + let stmt = Statement { kind: StatementKind::Assign(Place::from(result), rvalue), span }; + self.insert_stmt(stmt, before); + result + } + + /// Add a new assert to the basic block indicated by the given index. + /// + /// The new assertion will have the same span as the source instruction, and the basic block + /// will be split. The source instruction will be adjusted to point to the first instruction in + /// the new basic block. + pub fn add_check( + &mut self, + tcx: TyCtxt, + check_type: &CheckType, + source: &mut SourceInstruction, + value: Local, + msg: &str, + ) { + assert_eq!( + self.locals[value].ty, + Ty::bool_ty(), + "Expected boolean value as the assert input" + ); + let new_bb = self.blocks.len(); + let span = source.span(&self.blocks); + match check_type { + CheckType::Assert(assert_fn) => { + let assert_op = Operand::Copy(Place::from(self.new_local( + assert_fn.ty(), + span, + Mutability::Not, + ))); + let msg_op = self.new_str_operand(msg, span); + let kind = TerminatorKind::Call { + func: assert_op, + args: vec![Operand::Move(Place::from(value)), msg_op], + destination: Place { + local: self.new_local(Ty::new_tuple(&[]), span, Mutability::Not), + projection: vec![], + }, + target: Some(new_bb), + unwind: UnwindAction::Terminate, + }; + let terminator = Terminator { kind, span }; + self.split_bb(source, terminator); + } + CheckType::Panic(..) | CheckType::NoCore => { + tcx.sess + .dcx() + .struct_err("Failed to instrument the code. Cannot find `kani::assert`") + .with_note("Kani requires `kani` library in order to verify a crate.") + .emit(); + tcx.sess.dcx().abort_if_errors(); + unreachable!(); + } + } + } + + /// Split a basic block right before the source location and use the new terminator + /// in the basic block that was split. 
+ ///
+ /// The source is updated to point to the same instruction, which is now in the new basic block.
+ pub fn split_bb(&mut self, source: &mut SourceInstruction, new_term: Terminator) {
+ let new_bb_idx = self.blocks.len();
+ let (idx, bb) = match source {
+ SourceInstruction::Statement { idx, bb } => {
+ let (orig_idx, orig_bb) = (*idx, *bb);
+ *idx = 0;
+ *bb = new_bb_idx;
+ (orig_idx, orig_bb)
+ }
+ SourceInstruction::Terminator { bb } => {
+ let orig_bb = *bb;
+ *bb = new_bb_idx;
+ (self.blocks[orig_bb].statements.len(), orig_bb)
+ }
+ };
+ let old_term = mem::replace(&mut self.blocks[bb].terminator, new_term);
+ let bb_stmts = &mut self.blocks[bb].statements;
+ let remaining = bb_stmts.split_off(idx);
+ let new_bb = BasicBlock { statements: remaining, terminator: old_term };
+ self.blocks.push(new_bb);
+ }
+
+ /// Insert a statement before the source instruction and update the source as needed.
+ pub fn insert_stmt(&mut self, new_stmt: Statement, before: &mut SourceInstruction) {
+ match before {
+ SourceInstruction::Statement { idx, bb } => {
+ self.blocks[*bb].statements.insert(*idx, new_stmt);
+ *idx += 1;
+ }
+ SourceInstruction::Terminator { bb } => {
+ // Append statements at the end of the basic block.
+ self.blocks[*bb].statements.push(new_stmt);
+ }
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub enum CheckType {
+ /// This is used by default when the `kani` crate is available.
+ Assert(Instance),
+ /// When the `kani` crate is not available, we have to model the check as an `if { panic!() }`.
+ Panic(Instance),
+ /// When building a non-core crate, such as `rustc-std-workspace-core`, we cannot
+ /// instrument the code, but we can still compile it.
+ NoCore,
+}
+
+impl CheckType {
+ /// This will create the type of check that is available in the current crate.
+ ///
+ /// If the `kani` crate is available, this will return [CheckType::Assert], and the instance will
+ /// point to `kani::assert`. Otherwise, we will resolve the `core::panic_str` method and return
+ /// [CheckType::Panic].
+ pub fn new(tcx: TyCtxt) -> CheckType {
+ if let Some(instance) = find_instance(tcx, "KaniAssert") {
+ CheckType::Assert(instance)
+ } else if let Some(instance) = find_instance(tcx, "panic_str") {
+ CheckType::Panic(instance)
+ } else {
+ CheckType::NoCore
+ }
+ }
+}
+
+/// We store the index of an instruction to avoid borrow checker issues and unnecessary copies.
+#[derive(Copy, Clone, Debug)]
+pub enum SourceInstruction {
+ Statement { idx: usize, bb: BasicBlockIdx },
+ Terminator { bb: BasicBlockIdx },
+}
+
+impl SourceInstruction {
+ pub fn span(&self, blocks: &[BasicBlock]) -> Span {
+ match *self {
+ SourceInstruction::Statement { idx, bb } => blocks[bb].statements[idx].span,
+ SourceInstruction::Terminator { bb } => blocks[bb].terminator.span,
+ }
+ }
+}
+
+fn find_instance(tcx: TyCtxt, diagnostic: &str) -> Option<Instance> {
+ Instance::resolve(find_fn_def(tcx, diagnostic)?, &GenericArgs(vec![])).ok()
+}
diff --git a/kani-compiler/src/kani_middle/transform/check_values.rs b/kani-compiler/src/kani_middle/transform/check_values.rs
new file mode 100644
index 000000000000..aefc20b46a44
--- /dev/null
+++ b/kani-compiler/src/kani_middle/transform/check_values.rs
@@ -0,0 +1,924 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+//
+//! Implement a transformation pass that instruments the code to detect possible UB due to
+//! the generation of an invalid value.
+//!
+//! This pass depends heavily on Rust type layouts. For more details, see:
+//!
+//!
+//!
For that, we traverse the function body and look for unsafe operations that may generate +//! invalid values. For each operation found, we add checks to ensure the value is valid. +//! +//! Note: There is some redundancy in the checks that could be optimized. Example: +//! 1. We could merge the invalid values by the offset. +//! 2. We could avoid checking places that have been checked before. +use crate::args::ExtraChecks; +use crate::kani_middle::transform::body::{CheckType, MutableBody, SourceInstruction}; +use crate::kani_middle::transform::check_values::SourceOp::UnsupportedCheck; +use crate::kani_middle::transform::{TransformPass, TransformationType}; +use crate::kani_queries::QueryDb; +use rustc_middle::ty::TyCtxt; +use rustc_smir::rustc_internal; +use stable_mir::abi::{FieldsShape, Scalar, TagEncoding, ValueAbi, VariantsShape, WrappingRange}; +use stable_mir::mir::mono::{Instance, InstanceKind}; +use stable_mir::mir::visit::{Location, PlaceContext, PlaceRef}; +use stable_mir::mir::{ + AggregateKind, BasicBlockIdx, BinOp, Body, CastKind, Constant, FieldIdx, Local, LocalDecl, + MirVisitor, Mutability, NonDivergingIntrinsic, Operand, Place, ProjectionElem, Rvalue, + Statement, StatementKind, Terminator, TerminatorKind, +}; +use stable_mir::target::{MachineInfo, MachineSize}; +use stable_mir::ty::{AdtKind, Const, IndexedVal, RigidTy, Ty, TyKind, UintTy}; +use stable_mir::CrateDef; +use std::fmt::{Debug, Formatter}; +use strum_macros::AsRefStr; +use tracing::{debug, trace}; + +/// Instrument the code with checks for invalid values. +pub struct ValidValuePass { + check_type: CheckType, +} + +impl ValidValuePass { + pub fn new(tcx: TyCtxt) -> Self { + ValidValuePass { check_type: CheckType::new(tcx) } + } +} + +impl TransformPass for ValidValuePass { + fn transformation_type() -> TransformationType + where + Self: Sized, + { + TransformationType::Instrumentation + } + + fn is_enabled(&self, query_db: &QueryDb) -> bool + where + Self: Sized, + { + let args = query_db.args(); + args.ub_check.contains(&ExtraChecks::Validity) + } + + /// Transform the function body by inserting checks one-by-one. + /// For every unsafe dereference or a transmute operation, we check all values are valid. + fn transform(&self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body) { + trace!(function=?instance.name(), "transform"); + let mut new_body = MutableBody::from(body); + let orig_len = new_body.blocks().len(); + // Do not cache body.blocks().len() since it will change as we add new checks. + for bb_idx in 0..new_body.blocks().len() { + let Some(candidate) = + CheckValueVisitor::find_next(&new_body, bb_idx, bb_idx >= orig_len) + else { + continue; + }; + self.build_check(tcx, &mut new_body, candidate); + } + (orig_len != new_body.blocks().len(), new_body.into()) + } +} + +impl Debug for ValidValuePass { + /// Implement manually since MachineInfo doesn't currently derive Debug. 
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + "ValidValuePass".fmt(f) + } +} + +impl ValidValuePass { + fn build_check(&self, tcx: TyCtxt, body: &mut MutableBody, instruction: UnsafeInstruction) { + debug!(?instruction, "build_check"); + let mut source = instruction.source; + for operation in instruction.operations { + match operation { + SourceOp::BytesValidity { ranges, target_ty, rvalue } => { + let value = body.new_assignment(rvalue, &mut source); + let rvalue_ptr = Rvalue::AddressOf(Mutability::Not, Place::from(value)); + for range in ranges { + let result = + self.build_limits(body, &range, rvalue_ptr.clone(), &mut source); + let msg = format!( + "Undefined Behavior: Invalid value of type `{}`", + // TODO: Fix pretty_ty + rustc_internal::internal(tcx, target_ty) + ); + body.add_check(tcx, &self.check_type, &mut source, result, &msg); + } + } + SourceOp::DerefValidity { pointee_ty, rvalue, ranges } => { + for range in ranges { + let result = self.build_limits(body, &range, rvalue.clone(), &mut source); + let msg = format!( + "Undefined Behavior: Invalid value of type `{}`", + // TODO: Fix pretty_ty + rustc_internal::internal(tcx, pointee_ty) + ); + body.add_check(tcx, &self.check_type, &mut source, result, &msg); + } + } + SourceOp::UnsupportedCheck { check, ty } => { + let reason = format!( + "Kani currently doesn't support checking validity of `{check}` for `{}` type", + rustc_internal::internal(tcx, ty) + ); + self.unsupported_check(tcx, body, &mut source, &reason); + } + } + } + } + + fn build_limits( + &self, + body: &mut MutableBody, + req: &ValidValueReq, + rvalue_ptr: Rvalue, + source: &mut SourceInstruction, + ) -> Local { + let span = source.span(body.blocks()); + debug!(?req, ?rvalue_ptr, ?span, "build_limits"); + let primitive_ty = uint_ty(req.size.bytes()); + let start_const = body.new_const_operand(req.valid_range.start, primitive_ty, span); + let end_const = body.new_const_operand(req.valid_range.end, primitive_ty, span); + let orig_ptr = if req.offset != 0 { + let start_ptr = move_local(body.new_assignment(rvalue_ptr, source)); + let byte_ptr = move_local(body.new_cast_ptr( + start_ptr, + Ty::unsigned_ty(UintTy::U8), + Mutability::Not, + source, + )); + let offset_const = body.new_const_operand(req.offset as _, UintTy::Usize, span); + let offset = move_local(body.new_assignment(Rvalue::Use(offset_const), source)); + move_local(body.new_binary_op(BinOp::Offset, byte_ptr, offset, source)) + } else { + move_local(body.new_assignment(rvalue_ptr, source)) + }; + let value_ptr = + body.new_cast_ptr(orig_ptr, Ty::unsigned_ty(primitive_ty), Mutability::Not, source); + let value = + Operand::Copy(Place { local: value_ptr, projection: vec![ProjectionElem::Deref] }); + let start_result = body.new_binary_op(BinOp::Ge, value.clone(), start_const, source); + let end_result = body.new_binary_op(BinOp::Le, value, end_const, source); + if req.valid_range.wraps_around() { + // valid >= start || valid <= end + body.new_binary_op( + BinOp::BitOr, + move_local(start_result), + move_local(end_result), + source, + ) + } else { + // valid >= start && valid <= end + body.new_binary_op( + BinOp::BitAnd, + move_local(start_result), + move_local(end_result), + source, + ) + } + } + + fn unsupported_check( + &self, + tcx: TyCtxt, + body: &mut MutableBody, + source: &mut SourceInstruction, + reason: &str, + ) { + let span = source.span(body.blocks()); + let rvalue = Rvalue::Use(Operand::Constant(Constant { + literal: Const::from_bool(false), + span, + user_ty: None, + })); + let 
result = body.new_assignment(rvalue, source);
+ body.add_check(tcx, &self.check_type, source, result, reason);
+ }
+}
+
+fn move_local(local: Local) -> Operand {
+ Operand::Move(Place::from(local))
+}
+
+fn uint_ty(bytes: usize) -> UintTy {
+ match bytes {
+ 1 => UintTy::U8,
+ 2 => UintTy::U16,
+ 4 => UintTy::U32,
+ 8 => UintTy::U64,
+ 16 => UintTy::U128,
+ _ => unreachable!("Unexpected size: {bytes}"),
+ }
+}
+
+/// Represent a requirement for the value stored in the given offset.
+#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+struct ValidValueReq {
+ /// Offset in bytes.
+ offset: usize,
+ /// Size of this requirement.
+ size: MachineSize,
+ /// The range restriction is represented by a Scalar.
+ valid_range: WrappingRange,
+}
+
+// TODO: Optimize checks by merging requirements whenever possible.
+// There are a few cases that would need to be covered:
+// 1- Ranges intersection is the same as one of the ranges (or both).
+// 2- Ranges intersection is a new valid range.
+// 3- Ranges intersection is a combination of two new ranges.
+// 4- Intersection is empty.
+impl ValidValueReq {
+ /// Only a type with `ValueAbi::Scalar` or `ValueAbi::ScalarPair` can be directly assigned an
+ /// invalid value.
+ ///
+ /// It's not possible to define a `rustc_layout_scalar_valid_range_*` for any other structure.
+ /// Note that this annotation only applies to the first scalar in the layout.
+ pub fn try_from_ty(machine_info: &MachineInfo, ty: Ty) -> Option<ValidValueReq> {
+ let shape = ty.layout().unwrap().shape();
+ match shape.abi {
+ ValueAbi::Scalar(Scalar::Initialized { value, valid_range })
+ | ValueAbi::ScalarPair(Scalar::Initialized { value, valid_range }, _) => {
+ Some(ValidValueReq { offset: 0, size: value.size(machine_info), valid_range })
+ }
+ ValueAbi::Scalar(_)
+ | ValueAbi::ScalarPair(_, _)
+ | ValueAbi::Uninhabited
+ | ValueAbi::Vector { .. }
+ | ValueAbi::Aggregate { .. } => None,
+ }
+ }
+
+ /// Check if the range is full.
+ pub fn is_full(&self) -> bool {
+ self.valid_range.is_full(self.size).unwrap()
+ }
+
+ /// Check if this range contains the `other` range.
+ ///
+ /// I.e., `scalar_2` ⊆ `scalar_1`
+ pub fn contains(&self, other: &ValidValueReq) -> bool {
+ assert_eq!(self.size, other.size);
+ match (self.valid_range.wraps_around(), other.valid_range.wraps_around()) {
+ (true, true) | (false, false) => {
+ self.valid_range.start <= other.valid_range.start
+ && self.valid_range.end >= other.valid_range.end
+ }
+ (true, false) => {
+ self.valid_range.start <= other.valid_range.start
+ || self.valid_range.end >= other.valid_range.end
+ }
+ (false, true) => self.is_full(),
+ }
+ }
+}
+
+#[derive(AsRefStr, Clone, Debug)]
+enum SourceOp {
+ /// Validity checks are done on a byte level when the Rvalue can generate an invalid value.
+ ///
+ /// This variant tracks a location that is valid for its current type, but it may not be
+ /// valid for the given location in the target type. This happens for:
+ /// - Transmute
+ /// - Field assignment
+ /// - Aggregate assignment
+ /// - Union Access
+ ///
+ /// Each range is a pair of offset and scalar that represents the valid values.
+ /// Note that the same offset may have multiple ranges that may require being joined.
+ BytesValidity { target_ty: Ty, rvalue: Rvalue, ranges: Vec<ValidValueReq> },
+
+ /// Similar to BytesValidity, but it stores any dereference that may be unsafe.
+ ///
+ /// This can happen for:
+ /// - Raw pointer dereference
+ DerefValidity { pointee_ty: Ty, rvalue: Rvalue, ranges: Vec<ValidValueReq> },
+
+ /// Represents a range check Kani currently does not support.
+ ///
+ /// This will translate into an assertion failure with an unsupported message.
+ /// There are many corner cases with the usage of the #[rustc_layout_scalar_valid_range_*]
+ /// attribute, such as valid ranges that do not intersect or enumerations with niched
+ /// variants.
+ ///
+ /// Supporting all cases requires significant work, and such cases are unlikely to exist in
+ /// real-world code. To be on the sound side, we just emit an unsupported check, and users will
+ /// need to disable the check manually and create a feature request for their case.
+ ///
+ /// TODO: Consider replacing the assertion(false) with an unsupported operation that emits a
+ /// compilation warning.
+ UnsupportedCheck { check: String, ty: Ty },
+}
+
+/// The unsafe instructions that may generate invalid values.
+/// We need to instrument all operations to ensure the instruction is safe.
+#[derive(Clone, Debug)]
+struct UnsafeInstruction {
+ /// The instruction that depends on the potentially invalid value.
+ source: SourceInstruction,
+ /// The unsafe operations that may cause an invalid value in this instruction.
+ operations: Vec<SourceOp>,
+}
+
+/// Extract any source that may potentially trigger UB due to the generation of an invalid value.
+///
+/// Generating an invalid value requires an unsafe operation; however, in MIR, it
+/// may just be represented as a regular assignment.
+///
+/// Thus, we have to instrument every assignment to an object that has a niche and whose source
+/// is an object of a different type, e.g.:
+/// - Aggregate assignment
+/// - Transmute
+/// - MemCopy
+/// - Cast
+struct CheckValueVisitor<'a> {
+ locals: &'a [LocalDecl],
+ /// Whether we should skip the next instruction, since it might've been instrumented already.
+ /// When we instrument an instruction, we partition the basic block, and the instruction that
+ /// may trigger UB becomes the first instruction of the basic block, which we need to skip
+ /// later.
+ skip_next: bool,
+ /// The instruction being visited at a given point.
+ current: SourceInstruction,
+ /// The target instruction that should be verified.
+ pub target: Option<UnsafeInstruction>,
+ /// The basic block being visited.
+ bb: BasicBlockIdx,
+ /// Machine information needed to calculate the niche.
+ machine: MachineInfo,
+}
+
+impl<'a> CheckValueVisitor<'a> {
+ fn find_next(
+ body: &'a MutableBody,
+ bb: BasicBlockIdx,
+ skip_first: bool,
+ ) -> Option<UnsafeInstruction> {
+ let mut visitor = CheckValueVisitor {
+ locals: body.locals(),
+ skip_next: skip_first,
+ current: SourceInstruction::Statement { idx: 0, bb },
+ target: None,
+ bb,
+ machine: MachineInfo::target(),
+ };
+ visitor.visit_basic_block(&body.blocks()[bb]);
+ visitor.target
+ }
+
+ fn push_target(&mut self, op: SourceOp) {
+ let target = self
+ .target
+ .get_or_insert_with(|| UnsafeInstruction { source: self.current, operations: vec![] });
+ target.operations.push(op);
+ }
+}
+
+impl<'a> MirVisitor for CheckValueVisitor<'a> {
+ fn visit_statement(&mut self, stmt: &Statement, location: Location) {
+ if self.skip_next {
+ self.skip_next = false;
+ } else if self.target.is_none() {
+ // Leave it as an exhaustive match to be notified when a new kind is added.
+ match &stmt.kind {
+ StatementKind::Intrinsic(NonDivergingIntrinsic::CopyNonOverlapping(_)) => {
+ // Source and destination have the same type, so no invalid value can be
+ // generated.
+ }
+ StatementKind::Assign(place, rvalue) => {
+ // First check the rvalue.
+ self.super_statement(stmt, location);
+ // Then check the destination place.
+ let ranges = assignment_check_points( + &self.machine, + self.locals, + place, + rvalue.ty(self.locals).unwrap(), + ); + if !ranges.is_empty() { + self.push_target(SourceOp::BytesValidity { + target_ty: self.locals[place.local].ty, + rvalue: rvalue.clone(), + ranges, + }); + } + } + StatementKind::FakeRead(_, _) + | StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(_) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Retag(_, _) + | StatementKind::PlaceMention(_) + | StatementKind::AscribeUserType { .. } + | StatementKind::Coverage(_) + | StatementKind::ConstEvalCounter + | StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(_)) + | StatementKind::Nop => self.super_statement(stmt, location), + } + } + + let SourceInstruction::Statement { idx, bb } = self.current else { unreachable!() }; + self.current = SourceInstruction::Statement { idx: idx + 1, bb }; + } + fn visit_terminator(&mut self, term: &Terminator, location: Location) { + if !(self.skip_next || self.target.is_some()) { + self.current = SourceInstruction::Terminator { bb: self.bb }; + // Leave it as an exhaustive match to be notified when a new kind is added. + match &term.kind { + TerminatorKind::Call { func, args, .. } => { + // Note: For transmute, both Src and Dst must be valid type. + // In this case, we need to save the Dst, and invoke super_terminator. + self.super_terminator(term, location); + let instance = expect_instance(self.locals, func); + if instance.kind == InstanceKind::Intrinsic { + match instance.intrinsic_name().unwrap().as_str() { + "write_bytes" => { + // The write bytes intrinsic may trigger UB in safe code. + // pub unsafe fn write_bytes(dst: *mut T, val: u8, count: usize) + // + // We don't support this operation yet. + let TyKind::RigidTy(RigidTy::RawPtr(target_ty, Mutability::Mut)) = + args[0].ty(self.locals).unwrap().kind() + else { + unreachable!() + }; + let validity = ty_validity_per_offset(&self.machine, target_ty, 0); + match validity { + Ok(ranges) if ranges.is_empty() => {} + _ => self.push_target(SourceOp::UnsupportedCheck { + check: "write_bytes".to_string(), + ty: target_ty, + }), + } + } + "transmute" | "transmute_copy" => { + unreachable!("Should've been lowered") + } + _ => {} + } + } + } + TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::InlineAsm { .. 
} => self.super_terminator(term, location), + } + } + } + + fn visit_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { + for (idx, elem) in place.projection.iter().enumerate() { + let place_ref = PlaceRef { local: place.local, projection: &place.projection[..idx] }; + match elem { + ProjectionElem::Deref => { + let ptr_ty = place_ref.ty(self.locals).unwrap(); + if ptr_ty.kind().is_raw_ptr() { + let target_ty = elem.ty(ptr_ty).unwrap(); + let validity = ty_validity_per_offset(&self.machine, target_ty, 0); + match validity { + Ok(ranges) if !ranges.is_empty() => { + self.push_target(SourceOp::DerefValidity { + pointee_ty: target_ty, + rvalue: Rvalue::Use( + Operand::Copy(Place { + local: place_ref.local, + projection: place_ref.projection.to_vec(), + }) + .clone(), + ), + ranges, + }) + } + Err(_msg) => self.push_target(SourceOp::UnsupportedCheck { + check: "raw pointer dereference".to_string(), + ty: target_ty, + }), + _ => {} + } + } + } + ProjectionElem::Field(idx, target_ty) => { + if target_ty.kind().is_union() + && (!ptx.is_mutating() || place.projection.len() > idx + 1) + { + let validity = ty_validity_per_offset(&self.machine, *target_ty, 0); + match validity { + Ok(ranges) if !ranges.is_empty() => { + self.push_target(SourceOp::BytesValidity { + target_ty: *target_ty, + rvalue: Rvalue::Use(Operand::Copy(Place { + local: place_ref.local, + projection: place_ref.projection.to_vec(), + })), + ranges, + }) + } + Err(_msg) => self.push_target(SourceOp::UnsupportedCheck { + check: "union access".to_string(), + ty: *target_ty, + }), + _ => {} + } + } + } + ProjectionElem::Downcast(_) => {} + ProjectionElem::OpaqueCast(_) => {} + ProjectionElem::Subtype(_) => {} + ProjectionElem::Index(_) + | ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subslice { .. } => { /* safe */ } + } + } + self.super_place(place, ptx, location) + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue, location: Location) { + match rvalue { + Rvalue::Cast(kind, op, dest_ty) => match kind { + CastKind::PtrToPtr => { + // For mutable raw pointer, if the type we are casting to is less restrictive + // than the original type, writing to the pointer could generate UB if the + // value is ever read again using the original pointer. + let TyKind::RigidTy(RigidTy::RawPtr(dest_pointee_ty, Mutability::Mut)) = + dest_ty.kind() + else { + // We only care about *mut T as *mut U + return; + }; + let src_ty = op.ty(self.locals).unwrap(); + debug!(?src_ty, ?dest_ty, "visit_rvalue mutcast"); + let TyKind::RigidTy(RigidTy::RawPtr(src_pointee_ty, _)) = src_ty.kind() else { + unreachable!() + }; + if let Ok(src_validity) = + ty_validity_per_offset(&self.machine, src_pointee_ty, 0) + { + if !src_validity.is_empty() { + if let Ok(dest_validity) = + ty_validity_per_offset(&self.machine, dest_pointee_ty, 0) + { + if dest_validity != src_validity { + self.push_target(SourceOp::UnsupportedCheck { + check: "mutable cast".to_string(), + ty: src_ty, + }) + } + } else { + self.push_target(SourceOp::UnsupportedCheck { + check: "mutable cast".to_string(), + ty: *dest_ty, + }) + } + } + } else { + self.push_target(SourceOp::UnsupportedCheck { + check: "mutable cast".to_string(), + ty: src_ty, + }) + } + } + CastKind::Transmute => { + debug!(?dest_ty, "transmute"); + // For transmute, we care about the destination type only. + // This could be optimized to only add a check if the requirements of the + // destination type are stricter than the source. 
+ if let Ok(dest_validity) = ty_validity_per_offset(&self.machine, *dest_ty, 0) { + trace!(?dest_validity, "transmute"); + if !dest_validity.is_empty() { + self.push_target(SourceOp::BytesValidity { + target_ty: *dest_ty, + rvalue: rvalue.clone(), + ranges: dest_validity, + }) + } + } else { + self.push_target(SourceOp::UnsupportedCheck { + check: "transmute".to_string(), + ty: *dest_ty, + }) + } + } + CastKind::DynStar => self.push_target(UnsupportedCheck { + check: "Dyn*".to_string(), + ty: (rvalue.ty(self.locals).unwrap()), + }), + CastKind::PointerExposeAddress + | CastKind::PointerFromExposedAddress + | CastKind::PointerCoercion(_) + | CastKind::IntToInt + | CastKind::FloatToInt + | CastKind::FloatToFloat + | CastKind::IntToFloat + | CastKind::FnPtrToPtr => {} + }, + Rvalue::ShallowInitBox(_, _) => { + // The contents of the box is considered uninitialized. + // This should already be covered by the Assign detection. + } + Rvalue::Aggregate(kind, operands) => match kind { + // If the aggregated structure has invalid value, this could generate invalid value. + // But only if the operands don't have the exact same restrictions. + // This happens today with the usage of `rustc_layout_scalar_valid_range_*` + // attributes. + // In this case, only the value of the first member in memory can be restricted, + // thus, we only need to check the operand used to assign to the first in memory + // field. + AggregateKind::Adt(def, _variant, args, _, _) => { + if def.kind() == AdtKind::Struct { + let dest_ty = Ty::from_rigid_kind(RigidTy::Adt(*def, args.clone())); + if let Some(req) = ValidValueReq::try_from_ty(&self.machine, dest_ty) + && !req.is_full() + { + let dest_layout = dest_ty.layout().unwrap().shape(); + let first_op = + first_aggregate_operand(dest_ty, &dest_layout.fields, operands); + let first_ty = first_op.ty(self.locals).unwrap(); + // Rvalue must have same Abi layout except for range. + if !req.contains( + &ValidValueReq::try_from_ty(&self.machine, first_ty).unwrap(), + ) { + self.push_target(SourceOp::BytesValidity { + target_ty: dest_ty, + rvalue: Rvalue::Use(first_op), + ranges: vec![req], + }) + } + } + } + } + // Only aggregate value. + AggregateKind::Array(_) + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _, _) + | AggregateKind::Tuple => {} + }, + Rvalue::AddressOf(_, _) + | Rvalue::BinaryOp(_, _, _) + | Rvalue::CheckedBinaryOp(_, _, _) + | Rvalue::CopyForDeref(_) + | Rvalue::Discriminant(_) + | Rvalue::Len(_) + | Rvalue::Ref(_, _, _) + | Rvalue::Repeat(_, _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::NullaryOp(_, _) + | Rvalue::UnaryOp(_, _) + | Rvalue::Use(_) => {} + } + self.super_rvalue(rvalue, location); + } +} + +/// Gets the operand that corresponds to the assignment of the first sized field in memory. +/// +/// The first field of a structure is the only one that can have extra value restrictions imposed +/// by `rustc_layout_scalar_valid_range_*` attributes. +/// +/// Note: This requires at least one operand to be sized and there's a 1:1 match between operands +/// and field types. +fn first_aggregate_operand(dest_ty: Ty, dest_shape: &FieldsShape, operands: &[Operand]) -> Operand { + let Some(first) = first_sized_field_idx(dest_ty, dest_shape) else { unreachable!() }; + operands[first].clone() +} + +/// Index of the first non_1zst fields in memory order. 
+fn first_sized_field_idx(ty: Ty, shape: &FieldsShape) -> Option { + if let TyKind::RigidTy(RigidTy::Adt(adt_def, args)) = ty.kind() + && adt_def.kind() == AdtKind::Struct + { + let offset_order = shape.fields_by_offset_order(); + let fields = adt_def.variants_iter().next().unwrap().fields(); + offset_order + .into_iter() + .find(|idx| !fields[*idx].ty_with_args(&args).layout().unwrap().shape().is_1zst()) + } else { + None + } +} + +/// An assignment to a field with invalid values is unsafe, and it may trigger UB if +/// the assigned value is invalid. +/// +/// This can only happen to the first in memory sized field of a struct, and only if the field +/// type invalid range is a valid value for the rvalue type. +fn assignment_check_points( + machine_info: &MachineInfo, + locals: &[LocalDecl], + place: &Place, + rvalue_ty: Ty, +) -> Vec { + let mut ty = locals[place.local].ty; + let Some(rvalue_range) = ValidValueReq::try_from_ty(machine_info, rvalue_ty) else { + // Rvalue Abi must be Scalar / ScalarPair since destination must be Scalar / ScalarPair. + return vec![]; + }; + let mut invalid_ranges = vec![]; + for proj in &place.projection { + match proj { + ProjectionElem::Field(field_idx, field_ty) => { + let shape = ty.layout().unwrap().shape(); + if first_sized_field_idx(ty, &shape.fields) == Some(*field_idx) + && let Some(dest_valid) = ValidValueReq::try_from_ty(machine_info, ty) + && !dest_valid.is_full() + && dest_valid.size == rvalue_range.size + { + if !dest_valid.contains(&rvalue_range) { + invalid_ranges.push(dest_valid) + } + } else { + // Invalidate collected ranges so far since we are no longer in the path of + // the first element. + invalid_ranges.clear(); + } + ty = *field_ty; + } + ProjectionElem::Deref + | ProjectionElem::Index(_) + | ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subslice { .. } + | ProjectionElem::Downcast(_) + | ProjectionElem::OpaqueCast(_) + | ProjectionElem::Subtype(_) => ty = proj.ty(ty).unwrap(), + }; + } + invalid_ranges +} + +/// Retrieve instance for the given function operand. +/// +/// This will panic if the operand is not a function or if it cannot be resolved. +fn expect_instance(locals: &[LocalDecl], func: &Operand) -> Instance { + let ty = func.ty(locals).unwrap(); + match ty.kind() { + TyKind::RigidTy(RigidTy::FnDef(def, args)) => Instance::resolve(def, &args).unwrap(), + _ => unreachable!(), + } +} + +/// Traverse the type and find all invalid values and their location in memory. +/// +/// Not all values are currently supported. For those not supported, we return Error. 
+fn ty_validity_per_offset( + machine_info: &MachineInfo, + ty: Ty, + current_offset: usize, +) -> Result, String> { + let layout = ty.layout().unwrap().shape(); + let ty_req = || { + if let Some(mut req) = ValidValueReq::try_from_ty(machine_info, ty) + && !req.is_full() + { + req.offset = current_offset; + vec![req] + } else { + vec![] + } + }; + match layout.fields { + FieldsShape::Primitive => Ok(ty_req()), + FieldsShape::Array { stride, count } if count > 0 => { + let TyKind::RigidTy(RigidTy::Array(elem_ty, _)) = ty.kind() else { unreachable!() }; + let elem_validity = ty_validity_per_offset(machine_info, elem_ty, current_offset)?; + let mut result = vec![]; + if !elem_validity.is_empty() { + for idx in 0..count { + let idx: usize = idx.try_into().unwrap(); + let elem_offset = idx * stride.bytes(); + let mut next_validity = elem_validity + .iter() + .cloned() + .map(|mut req| { + req.offset += elem_offset; + req + }) + .collect::>(); + result.append(&mut next_validity) + } + } + Ok(result) + } + FieldsShape::Arbitrary { ref offsets } => { + match ty.kind().rigid().unwrap() { + RigidTy::Adt(def, args) => { + match def.kind() { + AdtKind::Enum => { + // Support basic enumeration forms + let ty_variants = def.variants(); + match layout.variants { + VariantsShape::Single { index } => { + // Only one variant is reachable. This behaves like a struct. + let fields = ty_variants[index.to_index()].fields(); + let mut fields_validity = vec![]; + for idx in layout.fields.fields_by_offset_order() { + let field_offset = offsets[idx].bytes(); + let field_ty = fields[idx].ty_with_args(&args); + fields_validity.append(&mut ty_validity_per_offset( + machine_info, + field_ty, + field_offset + current_offset, + )?); + } + Ok(fields_validity) + } + VariantsShape::Multiple { + tag_encoding: TagEncoding::Niche { .. }, + .. + } => { + Err(format!("Unsupported Enum `{}` check", def.trimmed_name()))? + } + VariantsShape::Multiple { variants, .. } => { + let enum_validity = ty_req(); + let mut fields_validity = vec![]; + for (index, variant) in variants.iter().enumerate() { + let fields = ty_variants[index].fields(); + for field_idx in variant.fields.fields_by_offset_order() { + let field_offset = offsets[field_idx].bytes(); + let field_ty = fields[field_idx].ty_with_args(&args); + fields_validity.append(&mut ty_validity_per_offset( + machine_info, + field_ty, + field_offset + current_offset, + )?); + } + } + if fields_validity.is_empty() { + Ok(enum_validity) + } else { + Err(format!( + "Unsupported Enum `{}` check", + def.trimmed_name() + )) + } + } + } + } + AdtKind::Union => unreachable!(), + AdtKind::Struct => { + // If the struct range has niche add that. 
+ let mut struct_validity = ty_req();
+ let fields = def.variants_iter().next().unwrap().fields();
+ for idx in layout.fields.fields_by_offset_order() {
+ let field_offset = offsets[idx].bytes();
+ let field_ty = fields[idx].ty_with_args(&args);
+ struct_validity.append(&mut ty_validity_per_offset(
+ machine_info,
+ field_ty,
+ field_offset + current_offset,
+ )?);
+ }
+ Ok(struct_validity)
+ }
+ }
+ }
+ RigidTy::Tuple(tys) => {
+ let mut tuple_validity = vec![];
+ for idx in layout.fields.fields_by_offset_order() {
+ let field_offset = offsets[idx].bytes();
+ let field_ty = tys[idx];
+ tuple_validity.append(&mut ty_validity_per_offset(
+ machine_info,
+ field_ty,
+ field_offset + current_offset,
+ )?);
+ }
+ Ok(tuple_validity)
+ }
+ RigidTy::Bool
+ | RigidTy::Char
+ | RigidTy::Int(_)
+ | RigidTy::Uint(_)
+ | RigidTy::Float(_)
+ | RigidTy::Never => {
+ unreachable!("Expected primitive layout for {ty:?}")
+ }
+ RigidTy::Str | RigidTy::Slice(_) | RigidTy::Array(_, _) => {
+ unreachable!("Expected array layout for {ty:?}")
+ }
+ RigidTy::RawPtr(_, _) | RigidTy::Ref(_, _, _) => {
+ // Fat pointer has arbitrary shape.
+ Ok(ty_req())
+ }
+ RigidTy::FnDef(_, _)
+ | RigidTy::FnPtr(_)
+ | RigidTy::Closure(_, _)
+ | RigidTy::Coroutine(_, _, _)
+ | RigidTy::CoroutineWitness(_, _)
+ | RigidTy::Foreign(_)
+ | RigidTy::Dynamic(_, _, _) => Err(format!("Unsupported {ty:?}")),
+ }
+ }
+ FieldsShape::Union(_) | FieldsShape::Array { .. } => {
+ /* Anything is valid */
+ Ok(vec![])
+ }
+ }
+}
diff --git a/kani-compiler/src/kani_middle/transform/mod.rs b/kani-compiler/src/kani_middle/transform/mod.rs
new file mode 100644
index 000000000000..a6cd17e8c7db
--- /dev/null
+++ b/kani-compiler/src/kani_middle/transform/mod.rs
@@ -0,0 +1,135 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+//
+//! This module is responsible for optimizing and instrumenting function bodies.
+//!
+//! We make transformations on bodies that have already been monomorphized, which allows us to
+//! make stronger decisions based on the instance types and constants.
+//!
+//! The main downside is that some transformations that don't depend on the specialized type may
+//! be applied multiple times, once per specialization.
+//!
+//! Another downside is that these modifications cannot be applied to concrete playback, since
+//! they are applied on top of the StableMIR body, which cannot be propagated back to rustc's
+//! backend.
+//!
+//! # Warn
+//!
+//! For all instrumentation passes, always use exhaustive matches to ensure soundness in case a
+//! new case is added.
+use crate::kani_middle::transform::check_values::ValidValuePass;
+use crate::kani_queries::QueryDb;
+use rustc_middle::ty::TyCtxt;
+use stable_mir::mir::mono::Instance;
+use stable_mir::mir::Body;
+use std::collections::HashMap;
+use std::fmt::Debug;
+
+mod body;
+mod check_values;
+
+/// Object used to retrieve a transformed instance body.
+/// The transformations to be applied may be controlled by user options.
+///
+/// The order, however, is always the same: we run optimizations first and instrument the code
+/// afterwards.
+#[derive(Debug)]
+pub struct BodyTransformation {
+ /// The passes that may optimize the function body.
+ /// We store them separately from the instrumentation passes because we run them in a specific order.
+ opt_passes: Vec<Box<dyn TransformPass>>,
+ /// The passes that may add safety checks to the function body.
+ inst_passes: Vec<Box<dyn TransformPass>>,
+ /// Cache transformation results.
+    cache: HashMap<Instance, TransformationResult>,
+}
+
+impl BodyTransformation {
+    pub fn new(queries: &QueryDb, tcx: TyCtxt) -> Self {
+        let mut transformer = BodyTransformation {
+            opt_passes: vec![],
+            inst_passes: vec![],
+            cache: Default::default(),
+        };
+        transformer.add_pass(queries, ValidValuePass::new(tcx));
+        transformer
+    }
+
+    /// Allow the creation of a dummy transformer that doesn't apply any transformation, due to
+    /// the stubbing validation hack (see the `collect_and_partition_mono_items` override).
+    /// Once we move the stubbing logic to a [TransformPass], we should be able to remove this.
+    pub fn dummy() -> Self {
+        BodyTransformation { opt_passes: vec![], inst_passes: vec![], cache: Default::default() }
+    }
+
+    /// Retrieve the body of an instance.
+    ///
+    /// Note that this assumes that the instance has a body, since existing consumers already
+    /// assume that. Use `instance.has_body()` to check if an instance has a body.
+    pub fn body(&mut self, tcx: TyCtxt, instance: Instance) -> Body {
+        match self.cache.get(&instance) {
+            Some(TransformationResult::Modified(body)) => body.clone(),
+            Some(TransformationResult::NotModified) => instance.body().unwrap(),
+            None => {
+                let mut body = instance.body().unwrap();
+                let mut modified = false;
+                for pass in self.opt_passes.iter().chain(self.inst_passes.iter()) {
+                    let result = pass.transform(tcx, body, instance);
+                    modified |= result.0;
+                    body = result.1;
+                }
+
+                let result = if modified {
+                    TransformationResult::Modified(body.clone())
+                } else {
+                    TransformationResult::NotModified
+                };
+                self.cache.insert(instance, result);
+                body
+            }
+        }
+    }
+
+    fn add_pass<P: TransformPass + 'static>(&mut self, query_db: &QueryDb, pass: P) {
+        if pass.is_enabled(query_db) {
+            match P::transformation_type() {
+                TransformationType::Instrumentation => self.inst_passes.push(Box::new(pass)),
+                TransformationType::Optimization => {
+                    unreachable!()
+                }
+            }
+        }
+    }
+}
+
+/// The type of transformation that a pass may perform.
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
+enum TransformationType {
+    /// Should only add assertion checks to ensure the program is correct.
+    Instrumentation,
+    /// May replace inefficient code with more performant but equivalent code.
+    #[allow(dead_code)]
+    Optimization,
+}
+
+/// A trait to represent transformation passes that can be used to modify the body of a function.
+trait TransformPass: Debug {
+    /// The type of transformation that this pass implements.
+    fn transformation_type() -> TransformationType
+    where
+        Self: Sized;
+
+    fn is_enabled(&self, query_db: &QueryDb) -> bool
+    where
+        Self: Sized;
+
+    /// Run a transformation pass on the function body.
+    fn transform(&self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body);
+}
+
+/// The transformation result.
+/// We currently only cache the body of functions that were instrumented.
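+///
+/// A minimal usage sketch of the caching behavior (hypothetical caller; not part of this
+/// change):
+/// ```ignore
+/// let mut transformer = BodyTransformation::new(&queries, tcx);
+/// let body = transformer.body(tcx, instance);   // runs the passes and caches the result
+/// let cached = transformer.body(tcx, instance); // the second call is served from the cache
+/// ```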
+#[derive(Clone, Debug)]
+enum TransformationResult {
+    Modified(Body),
+    NotModified,
+}
diff --git a/kani-driver/src/call_single_file.rs b/kani-driver/src/call_single_file.rs
index 4e8086e7e37b..24691943bc0f 100644
--- a/kani-driver/src/call_single_file.rs
+++ b/kani-driver/src/call_single_file.rs
@@ -2,6 +2,7 @@
 // SPDX-License-Identifier: Apache-2.0 OR MIT
 
 use anyhow::Result;
+use kani_metadata::UnstableFeature;
 use std::ffi::OsString;
 use std::path::{Path, PathBuf};
 use std::process::Command;
@@ -100,6 +101,10 @@ impl KaniSession {
             flags.push("--coverage-checks".into());
         }
 
+        if self.args.common_args.unstable_features.contains(UnstableFeature::ValidValueChecks) {
+            flags.push("--ub-check=validity".into())
+        }
+
         flags.extend(self.args.common_args.unstable_features.as_arguments().map(str::to_string));
 
         // This argument will select the Kani flavour of the compiler. It will be removed before
diff --git a/kani_metadata/src/unstable.rs b/kani_metadata/src/unstable.rs
index 3820f3f2238e..878b468dbdc3 100644
--- a/kani_metadata/src/unstable.rs
+++ b/kani_metadata/src/unstable.rs
@@ -84,6 +84,9 @@ pub enum UnstableFeature {
     FunctionContracts,
     /// Memory predicate APIs.
     MemPredicates,
+    /// Automatically check that no invalid value is produced, which is considered UB in Rust.
+    /// Note that this does not include checking for uninitialized values.
+    ValidValueChecks,
 }
 
 impl UnstableFeature {
diff --git a/tests/kani/ValidValues/constants.rs b/tests/kani/ValidValues/constants.rs
new file mode 100644
index 000000000000..5230e6e5e6cb
--- /dev/null
+++ b/tests/kani/ValidValues/constants.rs
@@ -0,0 +1,40 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+// kani-flags: -Z valid-value-checks
+//! Check that Kani can identify UB when it is reading from a constant.
+//! Note that this UB will be removed under `-Z mir-opt-level=2`.
+
+#[kani::proof]
+fn transmute_valid_bool() {
+    let _b = unsafe { std::mem::transmute::<u8, bool>(1) };
+}
+
+#[kani::proof]
+fn cast_to_valid_char() {
+    let _c = unsafe { *(&100u32 as *const u32 as *const char) };
+}
+
+#[kani::proof]
+fn cast_to_valid_offset() {
+    let val = [100u32, 80u32];
+    let _c = unsafe { *(&val as *const [u32; 2] as *const [char; 2]) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+fn transmute_invalid_bool() {
+    let _b = unsafe { std::mem::transmute::<u8, bool>(2) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+fn cast_to_invalid_char() {
+    let _c = unsafe { *(&u32::MAX as *const u32 as *const char) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+fn cast_to_invalid_offset() {
+    let val = [100u32, u32::MAX];
+    let _c = unsafe { *(&val as *const [u32; 2] as *const [char; 2]) };
+}
diff --git a/tests/kani/ValidValues/custom_niche.rs b/tests/kani/ValidValues/custom_niche.rs
new file mode 100644
index 000000000000..02b5b87bd092
--- /dev/null
+++ b/tests/kani/ValidValues/custom_niche.rs
@@ -0,0 +1,100 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+// kani-flags: -Z valid-value-checks
+//! Check that Kani can identify UB when using the niche attribute for a custom operation.
+#![feature(rustc_attrs)]
+
+use std::mem;
+use std::mem::size_of;
+
+/// A possible implementation of a rating system that defines a niche.
+/// A `Rating` represents the number of stars of a given product (1..=5).
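+/// Any bit pattern outside `1..=5` is immediate UB, which also gives the compiler a niche
+/// to exploit: e.g. `Option<Rating>` can use `0` to encode `None` (see `check_niche` below).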
+#[rustc_layout_scalar_valid_range_start(1)]
+#[rustc_layout_scalar_valid_range_end(5)]
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub struct Rating {
+    stars: u8,
+}
+
+impl kani::Arbitrary for Rating {
+    fn any() -> Self {
+        let stars = kani::any_where(|s: &u8| *s >= 1 && *s <= 5);
+        unsafe { Rating { stars } }
+    }
+}
+
+impl Rating {
+    /// Buggy version of `new`. Note that this still creates an invalid `Rating`.
+    ///
+    /// This is because `then_some` eagerly creates the `Rating` value before assessing the
+    /// condition. Even though the value is never used, it is still considered UB.
+    pub fn new(value: u8) -> Option<Rating> {
+        (value > 0 && value <= 5).then_some(unsafe { Rating { stars: value } })
+    }
+
+    pub unsafe fn new_unchecked(stars: u8) -> Rating {
+        Rating { stars }
+    }
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_new_with_ub() {
+    assert_eq!(Rating::new(10), None);
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_unchecked_new_ub() {
+    let val = kani::any();
+    assert_eq!(unsafe { Rating::new_unchecked(val).stars }, val);
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_new_with_ub_limits() {
+    let stars = kani::any_where(|s: &u8| *s == 0 || *s > 5);
+    let _ = Rating::new(stars);
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_invalid_dereference() {
+    let any: u8 = kani::any();
+    let _rating: Rating = unsafe { *(&any as *const _ as *const _) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_invalid_transmute() {
+    let any: u8 = kani::any();
+    let _rating: Rating = unsafe { mem::transmute(any) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_invalid_transmute_copy() {
+    let any: u8 = kani::any();
+    let _rating: Rating = unsafe { mem::transmute_copy(&any) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_invalid_increment() {
+    let mut orig: Rating = kani::any();
+    unsafe { orig.stars += 1 };
+}
+
+#[kani::proof]
+pub fn check_valid_increment() {
+    let mut orig: Rating = kani::any();
+    kani::assume(orig.stars < 5);
+    unsafe { orig.stars += 1 };
+}
+
+/// Check that the compiler relies on the valid value range of `Rating` to implement the niche
+/// optimization.
+#[kani::proof]
+pub fn check_niche() {
+    assert_eq!(size_of::<Rating>(), size_of::<Option<Rating>>());
+    assert_eq!(size_of::<Rating>(), size_of::<Option<Option<Rating>>>());
+}
diff --git a/tests/kani/ValidValues/non_null.rs b/tests/kani/ValidValues/non_null.rs
new file mode 100644
index 000000000000..4874b61bf2d0
--- /dev/null
+++ b/tests/kani/ValidValues/non_null.rs
@@ -0,0 +1,26 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+// kani-flags: -Z valid-value-checks
+//! Check that Kani can identify UB when unsafely writing to NonNull.
+
+use std::num::NonZeroU8;
+use std::ptr::{self, NonNull};
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_invalid_value() {
+    let _ = unsafe { NonNull::new_unchecked(ptr::null_mut::<NonZeroU8>()) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn check_invalid_value_cfg() {
+    let nn = unsafe { NonNull::new_unchecked(ptr::null_mut::<u8>()) };
+    // This should be unreachable. TODO: Make this an expected test.
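+    // (The valid-value check is expected to fail on `new_unchecked` above, so this
+    // assertion should never be evaluated.)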
+    assert_ne!(unsafe { nn.as_ref() }, &10);
+}
+
+#[kani::proof]
+pub fn check_valid_dangling() {
+    let _ = unsafe { NonNull::new_unchecked(4 as *mut u32) };
+}
diff --git a/tests/kani/ValidValues/write_invalid.rs b/tests/kani/ValidValues/write_invalid.rs
new file mode 100644
index 000000000000..05d3705bd69a
--- /dev/null
+++ b/tests/kani/ValidValues/write_invalid.rs
@@ -0,0 +1,37 @@
+// Copyright Kani Contributors
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+// kani-flags: -Z valid-value-checks
+
+//! Check that Kani can identify UB after writing an invalid value.
+//! Writing invalid bytes is not UB as long as the incorrect value is not read.
+//! However, we over-approximate for the sake of simplicity and performance.
+
+use std::num::NonZeroU8;
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn write_invalid_bytes_no_ub_with_spurious_cex() {
+    let mut non_zero: NonZeroU8 = kani::any();
+    let dest = &mut non_zero as *mut _;
+    unsafe { std::intrinsics::write_bytes(dest, 0, 1) };
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn write_valid_before_read() {
+    let mut non_zero: NonZeroU8 = kani::any();
+    let non_zero_2: NonZeroU8 = kani::any();
+    let dest = &mut non_zero as *mut _;
+    unsafe { std::intrinsics::write_bytes(dest, 0, 1) };
+    unsafe { std::intrinsics::write_bytes(dest, non_zero_2.get(), 1) };
+    assert_eq!(non_zero, non_zero_2)
+}
+
+#[kani::proof]
+#[kani::should_panic]
+pub fn read_invalid_is_ub() {
+    let mut non_zero: NonZeroU8 = kani::any();
+    let dest = &mut non_zero as *mut _;
+    unsafe { std::intrinsics::write_bytes(dest, 0, 1) };
+    assert_eq!(non_zero.get(), 0)
+}
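A quick usage sketch tying the flag plumbing together (command shape follows the
`kani-flags` headers in the new tests): enabling the unstable feature on the driver,
e.g.

    kani tests/kani/ValidValues/constants.rs -Z valid-value-checks

makes `kani-driver` forward `--ub-check=validity` to `kani-compiler`, which in turn
enables `ValidValuePass` inside `BodyTransformation`.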