Skip to content

Commit

Permalink
Mover CheckerConfig to Infer (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann authored Dec 17, 2024
1 parent a4b36ed commit 7073381
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 368 deletions.
7 changes: 4 additions & 3 deletions crates/flux-common/src/dbg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,17 @@ pub use crate::_basic_block_start as basic_block_start;

#[macro_export]
macro_rules! _statement{
($pos:literal, $stmt:expr, $rcx:expr, $env:expr, $span:expr, $checker:expr) => {{
($pos:literal, $stmt:expr, $infcx:expr, $env:expr, $span:expr, $checker:expr) => {{
if config::dump_checker_trace() {
let rcx = $infcx.rcx();
let ck = $checker;
let genv = ck.genv;
let local_names = &ck.body.local_names;
let local_decls = &ck.body.local_decls;
let rcx_json = RefineCtxtTrace::new(genv, $rcx);
let rcx_json = RefineCtxtTrace::new(genv, rcx);
let env_json = TypeEnvTrace::new(genv, local_names, local_decls, $env);
let span_json = SpanTrace::new(genv, $span);
tracing::info!(event = concat!("statement_", $pos), stmt = ?$stmt, stmt_span = ?$span, rcx = ?$rcx, env = ?$env, rcx_json = ?rcx_json, env_json = ?env_json, stmt_span_json = ?span_json)
tracing::info!(event = concat!("statement_", $pos), stmt = ?$stmt, stmt_span = ?$span, rcx = ?rcx, env = ?$env, rcx_json = ?rcx_json, env_json = ?env_json, stmt_span_json = ?span_json)
}
}};
}
Expand Down
13 changes: 2 additions & 11 deletions crates/flux-driver/src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use flux_metadata::CStore;
use flux_middle::{fhir, global_env::GlobalEnv, queries::Providers, Specs};
use flux_refineck as refineck;
use itertools::Itertools;
use refineck::CheckerConfig;
use rustc_borrowck::consumers::ConsumerOptions;
use rustc_driver::{Callbacks, Compilation};
use rustc_errors::ErrorGuaranteed;
Expand Down Expand Up @@ -137,17 +136,11 @@ fn encode_and_save_metadata(genv: GlobalEnv) {
struct CrateChecker<'genv, 'tcx> {
genv: GlobalEnv<'genv, 'tcx>,
cache: FixQueryCache,
checker_config: CheckerConfig,
}

impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> {
fn new(genv: GlobalEnv<'genv, 'tcx>) -> Self {
let crate_config = genv.crate_config().unwrap_or_default();
let checker_config = CheckerConfig {
check_overflow: crate_config.check_overflow,
scrape_quals: crate_config.scrape_quals,
};
CrateChecker { genv, cache: QueryCache::load(), checker_config }
CrateChecker { genv, cache: QueryCache::load() }
}

fn matches_check_def(&self, def_id: DefId) -> bool {
Expand Down Expand Up @@ -195,7 +188,7 @@ impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> {

match self.genv.def_kind(def_id) {
DefKind::Fn | DefKind::AssocFn => {
refineck::check_fn(self.genv, &mut self.cache, def_id, self.checker_config)
refineck::check_fn(self.genv, &mut self.cache, def_id)
}
DefKind::Enum => {
let adt_def = self.genv.adt_def(def_id).emit(&self.genv)?;
Expand All @@ -212,7 +205,6 @@ impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> {
def_id,
enum_def.invariants,
&adt_def,
self.checker_config,
)
}
DefKind::Struct => {
Expand All @@ -233,7 +225,6 @@ impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> {
def_id,
struct_def.invariants,
&adt_def,
self.checker_config,
)
}
DefKind::Impl { of_trait } => {
Expand Down
4 changes: 2 additions & 2 deletions crates/flux-infer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ version = "0.1.0"
flux-common.workspace = true
flux-config.workspace = true
flux-errors.workspace = true
flux-middle.workspace = true
flux-macros.workspace = true
flux-middle.workspace = true

liquid-fixpoint.workspace = true
itertools.workspace = true
liquid-fixpoint.workspace = true
pad-adapter.workspace = true
serde.workspace = true
serde_json.workspace = true
Expand Down
8 changes: 2 additions & 6 deletions crates/flux-infer/src/fixpoint_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,8 @@ pub struct KVarGen {
}

impl KVarGen {
pub fn new() -> Self {
Self { kvars: IndexVec::new(), dummy: false }
}

pub fn dummy() -> Self {
Self { kvars: IndexVec::new(), dummy: true }
pub(crate) fn new(dummy: bool) -> Self {
Self { kvars: IndexVec::new(), dummy }
}

fn get(&self, kvid: rty::KVid) -> &KVarDecl {
Expand Down
213 changes: 170 additions & 43 deletions crates/flux-infer/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
use std::{cell::RefCell, fmt, iter};

use flux_common::{bug, tracked_span_assert_eq, tracked_span_dbg_assert_eq};
use flux_common::{bug, dbg, tracked_span_assert_eq, tracked_span_dbg_assert_eq};
use flux_config as config;
use flux_middle::{
global_env::GlobalEnv,
queries::{QueryErr, QueryResult},
query_bug,
rty::{
self,
canonicalize::Hoister,
evars::{EVarSol, UnsolvedEvar},
fold::TypeFoldable,
AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate, ESpan,
EVar, EVarGen, EarlyBinder, Expr, ExprKind, GenericArg, GenericArgs, HoleKind, InferMode,
Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, RefineArgs, RefineArgsExt,
Lambda, List, Loc, Mutability, Name, Path, PolyVariant, PtrKind, RefineArgs, RefineArgsExt,
Region, Sort, Ty, TyKind, Var,
},
MaybeExternId,
};
use itertools::{izip, Itertools};
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_macros::extension;
use rustc_middle::{
mir::BasicBlock,
ty::{TyCtxt, Variance},
};
use rustc_span::Span;

use crate::{
fixpoint_encoding::{KVarEncoding, KVarGen},
refine_tree::{RefineCtxt, RefineTree, Scope, Snapshot},
fixpoint_encoding::{FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
refine_tree::{AssumeInvariants, RefineCtxt, RefineTree, Scope, Snapshot, Unpacker},
};

pub type InferResult<T = ()> = std::result::Result<T, InferErr>;
Expand Down Expand Up @@ -71,38 +75,129 @@ pub enum ConstrReason {
Other,
}

/// Options that change the behavior of refinement type inference locally
#[derive(Clone, Copy)]
pub struct InferOpts {
/// Enable overflow checking. This affects the signature of primitive operations and the
/// invariants assumed for primitive types.
pub check_overflow: bool,
/// Whether qualifiers should be scraped from the constraint.
pub scrape_quals: bool,
}

pub struct InferCtxtRoot<'genv, 'tcx> {
pub genv: GlobalEnv<'genv, 'tcx>,
inner: RefCell<InferCtxtInner>,
refine_tree: RefineTree,
opts: InferOpts,
}

impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
pub fn new(
genv: GlobalEnv<'genv, 'tcx>,
root_id: DefId,
kvar_gen: KVarGen,
args: Option<&GenericArgs>,
) -> QueryResult<Self> {
Ok(Self {
genv,
inner: RefCell::new(InferCtxtInner::new(kvar_gen)),
refine_tree: RefineTree::new(genv, root_id, args)?,
pub struct InferCtxtRootBuilder<'genv, 'tcx> {
genv: GlobalEnv<'genv, 'tcx>,
opts: InferOpts,
root_id: DefId,
generic_args: Option<GenericArgs>,
dummy_kvars: bool,
}

#[extension(pub trait GlobalEnvExt<'genv, 'tcx>)]
impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> {
fn infcx_root(self, root_id: DefId, opts: InferOpts) -> InferCtxtRootBuilder<'genv, 'tcx> {
InferCtxtRootBuilder { genv: self, root_id, opts, generic_args: None, dummy_kvars: false }
}
}

impl<'genv, 'tcx> InferCtxtRootBuilder<'genv, 'tcx> {
pub fn with_dummy_kvars(mut self) -> Self {
self.dummy_kvars = true;
self
}

pub fn with_generic_args(mut self, generic_args: &GenericArgs) -> Self {
self.generic_args = Some(generic_args.clone());
self
}

pub fn build(self) -> QueryResult<InferCtxtRoot<'genv, 'tcx>> {
let genv = self.genv;
let mut params = genv
.generics_of(self.root_id)?
.const_params(genv)?
.into_iter()
.map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort))
.collect_vec();
let offset = params.len();
self.genv.refinement_generics_of(self.root_id)?.fill_item(
self.genv,
&mut params,
&mut |param, index| {
let index = (index - offset) as u32;
let param = if let Some(args) = &self.generic_args {
param.instantiate(genv.tcx(), args, &[])
} else {
param.instantiate_identity()
};
let var = Var::EarlyParam(rty::EarlyReftParam { index, name: param.name });
(var, param.sort)
},
)?;

Ok(InferCtxtRoot {
genv: self.genv,
inner: RefCell::new(InferCtxtInner::new(self.dummy_kvars)),
refine_tree: RefineTree::new(params),
opts: self.opts,
})
}
}

impl<'genv, 'tcx> InferCtxtRoot<'genv, 'tcx> {
pub fn infcx<'a>(
&'a mut self,
def_id: DefId,
region_infcx: &'a rustc_infer::infer::InferCtxt<'tcx>,
) -> InferCtxt<'a, 'genv, 'tcx> {
InferCtxt::new(
self.genv,
InferCtxt {
genv: self.genv,
region_infcx,
def_id,
self.refine_tree.refine_ctxt_at_root(),
&self.inner,
)
rcx: self.refine_tree.refine_ctxt_at_root(),
inner: &self.inner,
check_overflow: self.opts.check_overflow,
}
}

pub fn fresh_kvar_in_scope(
&self,
binders: &[BoundVariableKinds],
scope: &Scope,
encoding: KVarEncoding,
) -> Expr {
let inner = &mut *self.inner.borrow_mut();
inner.kvars.fresh(binders, scope.iter(), encoding)
}

pub fn execute_fixpoint_query(
self,
cache: &mut FixQueryCache,
def_id: MaybeExternId,
ext: &'static str,
) -> QueryResult<Vec<Tag>> {
let mut refine_tree = self.refine_tree;
let kvars = self.inner.into_inner().kvars;
if config::dump_constraint() {
dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), ext, &refine_tree).unwrap();
}
refine_tree.simplify(self.genv.spec_func_defns()?);
if config::dump_constraint() {
let simp_ext = format!("simp.{}", ext);
dbg::dump_item_info(self.genv.tcx(), def_id.resolved_id(), simp_ext, &refine_tree)
.unwrap();
}

let mut fcx = FixpointCtxt::new(self.genv, def_id, kvars);
let cstr = refine_tree.into_fixpoint(&mut fcx)?;
fcx.check(cache, cstr, self.opts.scrape_quals)
}

pub fn split(self) -> (RefineTree, KVarGen) {
Expand All @@ -116,6 +211,7 @@ pub struct InferCtxt<'infcx, 'genv, 'tcx> {
pub def_id: DefId,
rcx: RefineCtxt<'infcx>,
inner: &'infcx RefCell<InferCtxtInner>,
pub check_overflow: bool,
}

struct InferCtxtInner {
Expand All @@ -124,22 +220,12 @@ struct InferCtxtInner {
}

impl InferCtxtInner {
fn new(kvars: KVarGen) -> Self {
Self { kvars, evars: Default::default() }
fn new(dummy_kvars: bool) -> Self {
Self { kvars: KVarGen::new(dummy_kvars), evars: Default::default() }
}
}

impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
fn new(
genv: GlobalEnv<'genv, 'tcx>,
region_infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
def_id: DefId,
rcx: RefineCtxt<'infcx>,
inner: &'infcx RefCell<InferCtxtInner>,
) -> Self {
Self { genv, region_infcx, def_id, rcx, inner }
}

pub fn clean_subtree(&mut self, snapshot: &Snapshot) {
self.rcx.clear_children(snapshot);
}
Expand Down Expand Up @@ -273,25 +359,66 @@ impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
pub fn tcx(&self) -> TyCtxt<'tcx> {
self.genv.tcx()
}
}

impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.rcx, f)
pub fn rcx(&self) -> &RefineCtxt<'infcx> {
&self.rcx
}
}

impl<'infcx> std::ops::Deref for InferCtxt<'infcx, '_, '_> {
type Target = RefineCtxt<'infcx>;
/// Delegate methods to [`RefineCtxt`]
impl<'infcx> InferCtxt<'infcx, '_, '_> {
pub fn define_vars(&mut self, sort: &Sort) -> Expr {
self.rcx.define_vars(sort)
}

fn deref(&self) -> &Self::Target {
&self.rcx
pub fn define_var(&mut self, sort: &Sort) -> Name {
self.rcx.define_var(sort)
}

pub fn check_pred(&mut self, pred: impl Into<Expr>, tag: Tag) {
self.rcx.check_pred(pred, tag);
}

pub fn replace_evars(&mut self, evars: &EVarSol) {
self.rcx.replace_evars(evars);
}

pub fn assume_pred(&mut self, pred: impl Into<Expr>) {
self.rcx.assume_pred(pred);
}

pub fn unpack(&mut self, ty: &Ty) -> Ty {
self.hoister(false).hoist(ty)
}

pub fn snapshot(&self) -> Snapshot {
self.rcx.snapshot()
}

pub fn hoister(&mut self, assume_invariants: bool) -> Hoister<Unpacker<'_, 'infcx>> {
self.rcx.hoister(if assume_invariants {
AssumeInvariants::yes(self.check_overflow)
} else {
AssumeInvariants::No
})
}

pub fn scope(&self) -> Scope {
self.rcx.scope()
}

pub fn assume_invariants(&mut self, ty: &Ty) {
self.rcx.assume_invariants(ty, self.check_overflow);
}

fn check_impl(&mut self, pred1: impl Into<Expr>, pred2: impl Into<Expr>, tag: Tag) {
self.rcx.check_impl(pred1, pred2, tag);
}
}

impl std::ops::DerefMut for InferCtxt<'_, '_, '_> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.rcx
impl std::fmt::Debug for InferCtxt<'_, '_, '_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.rcx, f)
}
}

Expand Down
Loading

0 comments on commit 7073381

Please sign in to comment.