Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support &strg , requires and ensures in function subtyping #891

Merged
merged 12 commits into from
Nov 20, 2024
33 changes: 32 additions & 1 deletion crates/flux-infer/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use flux_middle::{
fold::TypeFoldable,
AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate, ESpan,
EVar, EVarGen, EarlyBinder, Expr, ExprKind, GenericArg, GenericArgs, HoleKind, InferMode,
Lambda, List, Mutability, Path, PolyVariant, PtrKind, Ref, Region, Sort, Ty, TyKind, Var,
Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, Ref, Region, Sort, Ty, TyKind,
Var,
},
};
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -46,6 +47,14 @@ impl Tag {
}
}

#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum SubtypeReason {
Input,
Output,
Requires,
Ensures,
}

#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum ConstrReason {
Call,
Expand All @@ -57,6 +66,7 @@ pub enum ConstrReason {
Rem,
Goto(BasicBlock),
Overflow,
Subtype(SubtypeReason),
Other,
}

Expand Down Expand Up @@ -431,6 +441,7 @@ pub trait LocEnv {
bound: Ty,
) -> InferResult<Ty>;

fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult<Loc>;
fn get(&self, path: &Path) -> Ty;
}

Expand All @@ -444,6 +455,20 @@ struct Sub {
obligations: Vec<rty::Clause>,
}

/// [NOTE:unfold_strg_ref] We use this function to unfold a strong reference prior to a subtyping check.
/// Normally, when checking a function body, a `StrgRef` is automatically unfolded
/// i.e. `x:&strg T` is turned into turned into a `x:Ptr(l); l: T` where `l` is some
/// fresh location. However, we need the below to do a similar unfolding in `check_fn_subtyping`
/// where we just have the super-type signature that needs to be unfolded.
/// We also add the binding to the environment so that we can:
/// (1) UPDATE the location after the call, and
/// (2) CHECK the relevant `ensures` clauses of the super-sig.
/// Nico: More importantly, we are assuming functions always give back the "ownership"
/// of the location so even though we should technically "consume" the ownership and
/// remove the location from the environment, the type is always going to be overwritten.
/// (there's a check for this btw, if you write an &strg we require an ensures for that
/// location for the signature to be well-formed)

impl Sub {
fn new(reason: ConstrReason, span: Span) -> Self {
Self { reason, span, obligations: vec![] }
Expand Down Expand Up @@ -475,6 +500,12 @@ impl Sub {
infcx.unify_exprs(&path1.to_expr(), &path2.to_expr());
self.tys(infcx, &ty1, ty2)
}
(TyKind::StrgRef(_, path1, ty1), TyKind::StrgRef(_, path2, ty2)) => {
env.unfold_strg_ref(infcx, path1, ty1)?; // see [NOTE:unfold_strg_ref]
let ty1 = env.get(path1);
infcx.unify_exprs(&path1.to_expr(), &path2.to_expr());
self.tys(infcx, &ty1, ty2)
}
(TyKind::Ptr(PtrKind::Mut(re), path), Ref!(_, bound, Mutability::Mut)) => {
let mut at = infcx.at(self.span);
env.ptr_to_ref(&mut at, ConstrReason::Call, *re, path, bound.clone())?;
Expand Down
102 changes: 47 additions & 55 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::{collections::hash_map::Entry, iter};

use flux_common::{bug, dbg, index::IndexVec, iter::IterExt, span_bug, tracked_span_bug};
use flux_common::{bug, dbg, index::IndexVec, iter::IterExt, tracked_span_bug};
use flux_config as config;
use flux_infer::{
fixpoint_encoding::{self, KVarGen},
infer::{ConstrReason, InferCtxt, InferCtxtRoot},
infer::{ConstrReason, InferCtxt, InferCtxtRoot, SubtypeReason},
refine_tree::{AssumeInvariants, RefineTree, Snapshot},
};
use flux_middle::{
Expand All @@ -13,10 +13,9 @@ use flux_middle::{
query_bug,
rty::{
self, fold::TypeFoldable, refining::Refiner, AdtDef, BaseTy, Binder, Bool, Clause,
CoroutineObligPredicate, EarlyBinder, Ensures, Expr, FnOutput, FnTraitPredicate,
GenericArg, GenericArgs, GenericArgsExt as _, Int, IntTy, Mutability, Path, PolyFnSig,
PtrKind, Ref, RefineArgs, RefineArgsExt, Region::ReStatic, Ty, TyKind, Uint, UintTy,
VariantIdx,
CoroutineObligPredicate, EarlyBinder, Expr, FnOutput, FnTraitPredicate, GenericArg,
GenericArgs, GenericArgsExt as _, Int, IntTy, Mutability, Path, PolyFnSig, PtrKind, Ref,
RefineArgs, RefineArgsExt, Region::ReStatic, Ty, TyKind, Uint, UintTy, VariantIdx,
},
};
use flux_rustc_bridge::{
Expand Down Expand Up @@ -235,14 +234,15 @@ fn find_trait_item(
/// fn g(x1:T1,...,xn:Tn) -> T {
/// f(x1,...,xn)
/// }
/// TODO: copy rules from SLACK.
#[expect(clippy::too_many_arguments)]
fn check_fn_subtyping(
infcx: &mut InferCtxt,
def_id: &DefId,
sub_sig: EarlyBinder<rty::PolyFnSig>,
sub_args: &[GenericArg],
super_sig: EarlyBinder<rty::PolyFnSig>,
super_args: Option<(&GenericArgs, &rty::RefineArgs)>,
overflow_checking: bool,
span: Span,
) -> Result {
let mut infcx = infcx.branch();
Expand All @@ -268,6 +268,8 @@ fn check_fn_subtyping(
.map(|ty| infcx.unpack(ty))
.collect_vec();

let mut env = TypeEnv::empty();
let actuals = unfold_local_ptrs(&mut infcx, &mut env, span, &sub_sig, &actuals)?;
let actuals = infer_under_mut_ref_hack(&mut infcx, &actuals[..], sub_sig.as_ref());

// 2. Fresh names for `T_f` refine-params / Instantiate fn_def_sig and normalize it
Expand All @@ -280,27 +282,19 @@ fn check_fn_subtyping(
.with_span(span)?;

// 3. INPUT subtyping (g-input <: f-input)
// TODO: Check requires predicates (?)
// for requires in fn_def_sig.requires() {
// at.check_pred(requires, ConstrReason::Call);
// }
if !sub_sig.requires().is_empty() {
span_bug!(span, "Not yet handled: requires predicates {def_id:?}");
for requires in super_sig.requires() {
infcx.assume_pred(requires);
}
for (actual, formal) in iter::zip(actuals, sub_sig.inputs()) {
let (formal, pred) = formal.unconstr();
infcx.check_pred(&pred, ConstrReason::Call);
// see: TODO(pack-closure)
match (actual.kind(), formal.kind()) {
(TyKind::Ptr(rty::PtrKind::Mut(_), _), _) => {
bug!("Not yet handled: FnDef subtyping with Ptr");
}
_ => {
infcx
.subtyping(&actual, &formal, ConstrReason::Call)
.with_span(span)?;
}
}
let reason = ConstrReason::Subtype(SubtypeReason::Input);
infcx
.fun_arg_subtyping(&mut env, &actual, formal, reason)
.with_span(span)?;
}
// we check the requires AFTER the actual-formal subtyping as the above may unfold stuff in the actuals
for requires in sub_sig.requires() {
let reason = ConstrReason::Subtype(SubtypeReason::Requires);
infcx.check_pred(requires, reason);
}

// 4. Plug in the EVAR solution / replace evars
Expand All @@ -314,26 +308,32 @@ fn check_fn_subtyping(
// 5. OUTPUT subtyping (f_out <: g_out)
// RJ: new `at` to avoid borrowing errors...!
infcx.push_scope();
let oblig_output = super_sig
let super_output = super_sig
.output()
.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
let reason = ConstrReason::Subtype(SubtypeReason::Output);
infcx
.subtyping(&output.ret, &oblig_output.ret, ConstrReason::Ret)
.subtyping(&output.ret, &super_output.ret, reason)
.with_span(span)?;

// 6. Update state with Output "ensures" and check super ensures
env.update_ensures(&mut infcx, &output, overflow_checking);
fold_local_ptrs(&mut infcx, &mut env, span)?;
env.check_ensures(&mut infcx, &super_output, ConstrReason::Subtype(SubtypeReason::Ensures))
.with_span(span)?;

let evars_sol = infcx.pop_scope().with_span(span)?;
infcx.replace_evars(&evars_sol);

if !output.ensures.is_empty() || !oblig_output.ensures.is_empty() {
span_bug!(span, "Not yet handled: subtyping with ensures predicates {def_id:?}");
}

Ok(())
}

/// Trait subtyping check, which makes sure that the type for an impl method (def_id)
/// is a subtype of the corresponding trait method.
pub fn trait_impl_subtyping(
genv: GlobalEnv,
def_id: LocalDefId,
overflow_checking: bool,
span: Span,
) -> Result<Option<(RefineTree, KVarGen)>> {
// Skip the check if this is not an impl method
Expand Down Expand Up @@ -367,6 +367,7 @@ pub fn trait_impl_subtyping(
&impl_args,
trait_fn_sig,
Some((&trait_args, &trait_refine_args)),
overflow_checking,
span,
)?;
Ok(())
Expand Down Expand Up @@ -733,18 +734,8 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
.subtyping(&ret_place_ty, &output.ret, ConstrReason::Ret)
.with_span(span)?;

for constraint in &output.ensures {
match constraint {
Ensures::Type(path, ty) => {
let actual_ty = env.get(path);
at.subtyping(&actual_ty, ty, ConstrReason::Ret)
.with_span(span)?;
}
Ensures::Pred(e) => {
at.check_pred(e, ConstrReason::Ret);
}
}
}
env.check_ensures(&mut at, &output, ConstrReason::Ret)
.with_span(span)?;

let evars_sol = infcx.pop_scope().with_span(span)?;
infcx.replace_evars(&evars_sol);
Expand Down Expand Up @@ -831,16 +822,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
.replace_evars(&evars_sol)
.replace_bound_refts_with(|sort, _, _| infcx.define_vars(sort));

for ensures in &output.ensures {
match ensures {
Ensures::Type(path, updated_ty) => {
let updated_ty = infcx.unpack(updated_ty);
infcx.assume_invariants(&updated_ty, self.check_overflow());
env.update_path(path, updated_ty);
}
Ensures::Pred(e) => infcx.assume_pred(e),
}
}
env.update_ensures(infcx, &output, self.check_overflow());

fold_local_ptrs(infcx, env, span)?;

Expand Down Expand Up @@ -890,7 +872,16 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
// See `tests/neg/surface/fndef00.rs`
let sub_sig = self.genv.fn_sig(def_id).with_span(span)?;
let oblig_sig = fn_trait_pred.fndef_poly_sig();
check_fn_subtyping(infcx, def_id, sub_sig, args, oblig_sig, None, span)?;
check_fn_subtyping(
infcx,
def_id,
sub_sig,
args,
oblig_sig,
None,
self.check_overflow(),
span,
)?;
}
_ => {
// TODO: When we allow refining closure/fn at the surface level, we would need to do some function subtyping here,
Expand Down Expand Up @@ -1369,6 +1360,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
args,
super_sig,
None,
self.check_overflow(),
stmt_span,
)?;
to
Expand Down
15 changes: 11 additions & 4 deletions crates/flux-refineck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use flux_common::{cache::QueryCache, dbg, result::ResultExt as _};
use flux_config as config;
use flux_infer::{
fixpoint_encoding::{FixpointCtxt, KVarGen},
infer::{ConstrReason, Tag},
infer::{ConstrReason, SubtypeReason, Tag},
refine_tree::RefineTree,
};
use flux_macros::fluent_messages;
Expand Down Expand Up @@ -169,7 +169,8 @@ pub fn check_fn(

// PHASE 4: subtyping check for trait-method implementations
if let Some((refine_tree, kvars)) =
trait_impl_subtyping(genv, local_id, span).map_err(|err| err.emit_err(&genv, def_id))?
trait_impl_subtyping(genv, local_id, config.check_overflow, span)
.map_err(|err| err.emit_err(&genv, def_id))?
{
tracing::info!("check_fn::refine-subtyping");
let errors =
Expand Down Expand Up @@ -206,9 +207,15 @@ fn report_errors(genv: GlobalEnv, errors: Vec<Tag>) -> Result<(), ErrorGuarantee
for err in errors {
let span = err.src_span;
e = Some(match err.reason {
ConstrReason::Call => call_error(genv, span, err.dst_span),
ConstrReason::Call
| ConstrReason::Subtype(SubtypeReason::Input)
| ConstrReason::Subtype(SubtypeReason::Requires) => {
call_error(genv, span, err.dst_span)
}
ConstrReason::Assign => genv.sess().emit_err(errors::AssignError { span }),
ConstrReason::Ret => ret_error(genv, span, err.dst_span),
ConstrReason::Ret
| ConstrReason::Subtype(SubtypeReason::Output)
| ConstrReason::Subtype(SubtypeReason::Ensures) => ret_error(genv, span, err.dst_span),
ConstrReason::Div => genv.sess().emit_err(errors::DivError { span }),
ConstrReason::Rem => genv.sess().emit_err(errors::RemError { span }),
ConstrReason::Goto(_) => genv.sess().emit_err(errors::GotoError { span }),
Expand Down
Loading
Loading