Skip to content

Commit

Permalink
Add a lint against never type fallback affecting unsafe code
Browse files Browse the repository at this point in the history
  • Loading branch information
WaffleLapkin committed Apr 14, 2024
1 parent 114d7d1 commit 7dad013
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 11 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_data_structures/src/graph/vec_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rustc_index::{Idx, IndexVec};
#[cfg(test)]
mod tests;

#[derive(Debug)]
pub struct VecGraph<N: Idx> {
/// Maps from a given node to an index where the set of successors
/// for that node starts. The index indexes into the `edges`
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_hir_typeck/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,7 @@ hir_typeck_use_is_empty =
hir_typeck_yield_expr_outside_of_coroutine =
yield expression outside of coroutine literal
hir_typeck_never_type_fallback_flowing_into_unsafe =
never type fallback affects this call to an `unsafe` function
.help = specify the type explicitly
5 changes: 5 additions & 0 deletions compiler/rustc_hir_typeck/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,3 +639,8 @@ pub enum SuggestBoxingForReturnImplTrait {
ends: Vec<Span>,
},
}

#[derive(LintDiagnostic)]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe)]
#[help]
pub struct NeverTypeFallbackFlowingIntoUnsafe {}
154 changes: 144 additions & 10 deletions compiler/rustc_hir_typeck/src/fallback.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use crate::FnCtxt;
use std::cell::OnceCell;

use crate::{errors, FnCtxt};
use rustc_data_structures::{
graph::{self, iterate::DepthFirstSearch, vec_graph::VecGraph},
unord::{UnordBag, UnordMap, UnordSet},
};
use rustc_hir::HirId;
use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitable};
use rustc_session::lint;
use rustc_span::Span;

#[derive(Copy, Clone)]
pub enum DivergingFallbackBehavior {
Expand Down Expand Up @@ -251,7 +256,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {

// Construct a coercion graph where an edge `A -> B` indicates
// a type variable is that is coerced
let coercion_graph = self.create_coercion_graph();
let (coercion_graph, coercion_graph2) = self.create_coercion_graph();

// Extract the unsolved type inference variable vids; note that some
// unsolved variables are integer/float variables and are excluded.
Expand Down Expand Up @@ -338,6 +343,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
// reach a member of N. If so, it falls back to `()`. Else
// `!`.
let mut diverging_fallback = UnordMap::with_capacity(diverging_vids.len());
let unsafe_infer_vars = OnceCell::new();
for &diverging_vid in &diverging_vids {
let diverging_ty = Ty::new_var(self.tcx, diverging_vid);
let root_vid = self.root_var(diverging_vid);
Expand All @@ -357,11 +363,35 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
output: infer_var_infos.items().any(|info| info.output),
};

let mut fallback_to = |ty| {
let unsafe_infer_vars = unsafe_infer_vars.get_or_init(|| {
let unsafe_infer_vars = compute_unsafe_infer_vars(self.root_ctxt, self.body_id);
debug!(?unsafe_infer_vars);
unsafe_infer_vars
});

let affected_unsafe_infer_vars =
graph::depth_first_search(&coercion_graph2, root_vid)
.filter_map(|x| unsafe_infer_vars.get(&x).copied())
.collect::<Vec<_>>();

for (hir_id, span) in affected_unsafe_infer_vars {
self.tcx.emit_node_span_lint(
lint::builtin::NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
hir_id,
span,
errors::NeverTypeFallbackFlowingIntoUnsafe {},
);
}

diverging_fallback.insert(diverging_ty, ty);
};

use DivergingFallbackBehavior::*;
match behavior {
FallbackToUnit => {
debug!("fallback to () - legacy: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
}
FallbackToNiko => {
if found_infer_var_info.self_in_trait && found_infer_var_info.output {
Expand Down Expand Up @@ -390,21 +420,21 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
// set, see the relationship finding module in
// compiler/rustc_trait_selection/src/traits/relationships.rs.
debug!("fallback to () - found trait and projection: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
} else if can_reach_non_diverging {
debug!("fallback to () - reached non-diverging: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
} else {
debug!("fallback to ! - all diverging: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
fallback_to(self.tcx.types.never);
}
}
FallbackToNever => {
debug!(
"fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}",
diverging_vid
);
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
fallback_to(self.tcx.types.never);
}
NoFallback => {
debug!(
Expand All @@ -420,7 +450,9 @@ impl<'tcx> FnCtxt<'_, 'tcx> {

/// Returns a graph whose nodes are (unresolved) inference variables and where
/// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`.
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid> {
///
/// The second element of the return tuple is a graph with edges in both directions.
fn create_coercion_graph(&self) -> (VecGraph<ty::TyVid>, VecGraph<ty::TyVid>) {
let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations);
let coercion_edges: Vec<(ty::TyVid, ty::TyVid)> = pending_obligations
Expand Down Expand Up @@ -454,11 +486,113 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
.collect();
debug!("create_coercion_graph: coercion_edges={:?}", coercion_edges);
let num_ty_vars = self.num_ty_vars();
VecGraph::new(num_ty_vars, coercion_edges)

// This essentially creates a non-directed graph.
// Ideally we wouldn't do it like this, but it works ig :\
let doubly_connected = VecGraph::new(
num_ty_vars,
coercion_edges
.iter()
.copied()
.chain(coercion_edges.iter().copied().map(|(a, b)| (b, a)))
.collect(),
);

let normal = VecGraph::new(num_ty_vars, coercion_edges.clone());

(normal, doubly_connected)
}

/// If `ty` is an unresolved type variable, returns its root vid.
fn root_vid(&self, ty: Ty<'tcx>) -> Option<ty::TyVid> {
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
}
}

/// Finds all type variables which are passed to an `unsafe` function.
///
/// For example, for this function `f`:
/// ```ignore (demonstrative)
/// fn f() {
/// unsafe {
/// let x /* ?X */ = core::mem::zeroed();
/// // ^^^^^^^^^^^^^^^^^^^ -- hir_id, span
///
/// let y = core::mem::zeroed::<Option<_ /* ?Y */>>();
/// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span
/// }
/// }
/// ```
///
/// Will return `{ id(?X) -> (hir_id, span) }`
fn compute_unsafe_infer_vars<'a, 'tcx>(
root_ctxt: &'a crate::TypeckRootCtxt<'tcx>,
body_id: rustc_span::def_id::LocalDefId,
) -> UnordMap<ty::TyVid, (HirId, Span)> {
use rustc_hir as hir;

let tcx = root_ctxt.infcx.tcx;
let body_id = tcx.hir().maybe_body_owned_by(body_id).unwrap();
let body = tcx.hir().body(body_id);
let mut res = <_>::default();

struct UnsafeInferVarsVisitor<'a, 'tcx, 'r> {
root_ctxt: &'a crate::TypeckRootCtxt<'tcx>,
res: &'r mut UnordMap<ty::TyVid, (HirId, Span)>,
}

use hir::intravisit::Visitor;
impl hir::intravisit::Visitor<'_> for UnsafeInferVarsVisitor<'_, '_, '_> {
fn visit_expr(&mut self, ex: &'_ hir::Expr<'_>) {
// FIXME: method calls
if let hir::ExprKind::Call(func, ..) = ex.kind {
let typeck_results = self.root_ctxt.typeck_results.borrow();

let func_ty = typeck_results.expr_ty(func);

// `is_fn` is required to ignore closures (which can't be unsafe)
if func_ty.is_fn()
&& let sig = func_ty.fn_sig(self.root_ctxt.infcx.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector =
InferVarCollector { hir_id: ex.hir_id, call_span: ex.span, res: self.res };

// Collect generic arguments of the function which are inference variables
typeck_results
.node_args(ex.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));

// Also check the return type, for cases like `(unsafe_fn::<_> as unsafe fn() -> _)()`
sig.output().visit_with(&mut collector);
}
}

hir::intravisit::walk_expr(self, ex);
}
}

struct InferVarCollector<'r> {
hir_id: HirId,
call_span: Span,
res: &'r mut UnordMap<ty::TyVid, (HirId, Span)>,
}

impl<'tcx> ty::TypeVisitor<TyCtxt<'tcx>> for InferVarCollector<'_> {
fn visit_ty(&mut self, t: Ty<'tcx>) {
if let Some(vid) = t.ty_vid() {
self.res.insert(vid, (self.hir_id, self.call_span));
} else {
use ty::TypeSuperVisitable as _;
t.super_visit_with(self)
}
}
}

UnsafeInferVarsVisitor { root_ctxt, res: &mut res }.visit_expr(&body.value);

debug!(?res, "collected the following unsafe vars for {body_id:?}");

res
}
4 changes: 3 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ mod arg_matrix;
mod checks;
mod suggestions;

use rustc_errors::ErrorGuaranteed;

use crate::coercion::DynamicCoerceMany;
use crate::fallback::DivergingFallbackBehavior;
use crate::fn_ctxt::checks::DivergingBlockBehavior;
use crate::{CoroutineTypes, Diverges, EnclosingBreakables, TypeckRootCtxt};
use hir::def_id::CRATE_DEF_ID;
use rustc_errors::{DiagCtxt, ErrorGuaranteed};
use rustc_errors::DiagCtxt;
use rustc_hir as hir;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
Expand Down
44 changes: 44 additions & 0 deletions compiler/rustc_lint_defs/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ declare_lint_pass! {
MISSING_FRAGMENT_SPECIFIER,
MUST_NOT_SUSPEND,
NAMED_ARGUMENTS_USED_POSITIONALLY,
NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
NON_CONTIGUOUS_RANGE_ENDPOINTS,
NON_EXHAUSTIVE_OMITTED_PATTERNS,
ORDER_DEPENDENT_TRAIT_OBJECTS,
Expand Down Expand Up @@ -4179,6 +4180,49 @@ declare_lint! {
"named arguments in format used positionally"
}

declare_lint! {
/// The `never_type_fallback_flowing_into_unsafe` lint detects cases where never type fallback
/// affects unsafe function calls.
///
/// ### Example
///
/// ```rust,compile_fail
/// #![deny(never_type_fallback_flowing_into_unsafe)]
/// fn main() {
/// if true {
/// // return has type `!` (never) which, is some cases, causes never type fallback
/// return
/// } else {
/// // `zeroed` is an unsafe function, which returns an unbounded type
/// unsafe { std::mem::zeroed() }
/// };
/// // depending on the fallback, `zeroed` may create `()` (which is completely sound),
/// // or `!` (which is instant undefined behavior)
/// }
/// ```
///
/// {{produces}}
///
/// ### Explanation
///
/// Due to historic reasons never type fallback were `()`, meaning that `!` got spontaneously
/// coerced to `()`. There are plans to change that, but they may make the code such as above
/// unsound. Instead of depending on the fallback, you should specify the type explicitly:
/// ```
/// if true {
/// return
/// } else {
/// // type is explicitly specified, fallback can't hurt us no more
/// unsafe { std::mem::zeroed::<()>() }
/// };
/// ```
///
/// See [Tracking Issue for making `!` fall back to `!`](https://github.com/rust-lang/rust/issues/123748).
pub NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
Warn,
"never type fallback affecting unsafe function calls"
}

declare_lint! {
/// The `byte_slice_in_packed_struct_with_derive` lint detects cases where a byte slice field
/// (`[u8]`) or string slice field (`str`) is used in a `packed` struct that derives one or
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//@ check-pass
use std::mem;

fn main() {
if false {
unsafe { mem::zeroed() }
//~^ warn: never type fallback affects this call to an `unsafe` function
} else {
return;
};

// no ; -> type is inferred without fallback
if true { unsafe { mem::zeroed() } } else { return }
}

// Minimization of the famous `objc` crate issue
fn _objc() {
pub unsafe fn send_message<R>() -> Result<R, ()> {
Ok(unsafe { core::mem::zeroed() })
}

macro_rules! msg_send {
() => {
match send_message::<_ /* ?0 */>() {
//~^ warn: never type fallback affects this call to an `unsafe` function
Ok(x) => x,
Err(_) => loop {},
}
};
}

unsafe {
msg_send!();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
warning: never type fallback affects this call to an `unsafe` function
--> $DIR/lint-never-type-fallback-flowing-into-unsafe.rs:6:18
|
LL | unsafe { mem::zeroed() }
| ^^^^^^^^^^^^^
|
= help: specify the type explicitly
= note: `#[warn(never_type_fallback_flowing_into_unsafe)]` on by default

warning: never type fallback affects this call to an `unsafe` function
--> $DIR/lint-never-type-fallback-flowing-into-unsafe.rs:24:19
|
LL | match send_message::<_ /* ?0 */>() {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
LL | msg_send!();
| ----------- in this macro invocation
|
= help: specify the type explicitly
= note: this warning originates in the macro `msg_send` (in Nightly builds, run with -Z macro-backtrace for more info)

warning: 2 warnings emitted

0 comments on commit 7dad013

Please sign in to comment.