diff --git a/crates/flux-infer/src/infer.rs b/crates/flux-infer/src/infer.rs index c1d51f1e51..fc80ebfe5f 100644 --- a/crates/flux-infer/src/infer.rs +++ b/crates/flux-infer/src/infer.rs @@ -11,7 +11,8 @@ use flux_middle::{ fold::TypeFoldable, AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate, ESpan, EVar, EVarGen, EarlyBinder, Expr, ExprKind, GenericArg, GenericArgs, HoleKind, InferMode, - Lambda, List, Mutability, Path, PolyVariant, PtrKind, Ref, Region, Sort, Ty, TyKind, Var, + Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, Ref, Region, Sort, Ty, TyKind, + Var, }, }; use itertools::{izip, Itertools}; @@ -46,6 +47,14 @@ impl Tag { } } +#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] +pub enum SubtypeReason { + Input, + Output, + Requires, + Ensures, +} + #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] pub enum ConstrReason { Call, @@ -57,6 +66,7 @@ pub enum ConstrReason { Rem, Goto(BasicBlock), Overflow, + Subtype(SubtypeReason), Other, } @@ -431,6 +441,7 @@ pub trait LocEnv { bound: Ty, ) -> InferResult; + fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult; fn get(&self, path: &Path) -> Ty; } @@ -444,6 +455,20 @@ struct Sub { obligations: Vec, } +/// [NOTE:unfold_strg_ref] We use this function to unfold a strong reference prior to a subtyping check. +/// Normally, when checking a function body, a `StrgRef` is automatically unfolded +/// i.e. `x:&strg T` is turned into turned into a `x:Ptr(l); l: T` where `l` is some +/// fresh location. However, we need the below to do a similar unfolding in `check_fn_subtyping` +/// where we just have the super-type signature that needs to be unfolded. +/// We also add the binding to the environment so that we can: +/// (1) UPDATE the location after the call, and +/// (2) CHECK the relevant `ensures` clauses of the super-sig. +/// Nico: More importantly, we are assuming functions always give back the "ownership" +/// of the location so even though we should technically "consume" the ownership and +/// remove the location from the environment, the type is always going to be overwritten. +/// (there's a check for this btw, if you write an &strg we require an ensures for that +/// location for the signature to be well-formed) + impl Sub { fn new(reason: ConstrReason, span: Span) -> Self { Self { reason, span, obligations: vec![] } @@ -475,6 +500,12 @@ impl Sub { infcx.unify_exprs(&path1.to_expr(), &path2.to_expr()); self.tys(infcx, &ty1, ty2) } + (TyKind::StrgRef(_, path1, ty1), TyKind::StrgRef(_, path2, ty2)) => { + env.unfold_strg_ref(infcx, path1, ty1)?; // see [NOTE:unfold_strg_ref] + let ty1 = env.get(path1); + infcx.unify_exprs(&path1.to_expr(), &path2.to_expr()); + self.tys(infcx, &ty1, ty2) + } (TyKind::Ptr(PtrKind::Mut(re), path), Ref!(_, bound, Mutability::Mut)) => { let mut at = infcx.at(self.span); env.ptr_to_ref(&mut at, ConstrReason::Call, *re, path, bound.clone())?; diff --git a/crates/flux-refineck/src/checker.rs b/crates/flux-refineck/src/checker.rs index 060a6a8c1e..c2e3b15a6b 100644 --- a/crates/flux-refineck/src/checker.rs +++ b/crates/flux-refineck/src/checker.rs @@ -1,10 +1,10 @@ use std::{collections::hash_map::Entry, iter}; -use flux_common::{bug, dbg, index::IndexVec, iter::IterExt, span_bug, tracked_span_bug}; +use flux_common::{bug, dbg, index::IndexVec, iter::IterExt, tracked_span_bug}; use flux_config as config; use flux_infer::{ fixpoint_encoding::{self, KVarGen}, - infer::{ConstrReason, InferCtxt, InferCtxtRoot}, + infer::{ConstrReason, InferCtxt, InferCtxtRoot, SubtypeReason}, refine_tree::{AssumeInvariants, RefineTree, Snapshot}, }; use flux_middle::{ @@ -13,10 +13,9 @@ use flux_middle::{ query_bug, rty::{ self, fold::TypeFoldable, refining::Refiner, AdtDef, BaseTy, Binder, Bool, Clause, - CoroutineObligPredicate, EarlyBinder, Ensures, Expr, FnOutput, FnTraitPredicate, - GenericArg, GenericArgs, GenericArgsExt as _, Int, IntTy, Mutability, Path, PolyFnSig, - PtrKind, Ref, RefineArgs, RefineArgsExt, Region::ReStatic, Ty, TyKind, Uint, UintTy, - VariantIdx, + CoroutineObligPredicate, EarlyBinder, Expr, FnOutput, FnTraitPredicate, GenericArg, + GenericArgs, GenericArgsExt as _, Int, IntTy, Mutability, Path, PolyFnSig, PtrKind, Ref, + RefineArgs, RefineArgsExt, Region::ReStatic, Ty, TyKind, Uint, UintTy, VariantIdx, }, }; use flux_rustc_bridge::{ @@ -235,7 +234,7 @@ fn find_trait_item( /// fn g(x1:T1,...,xn:Tn) -> T { /// f(x1,...,xn) /// } -/// TODO: copy rules from SLACK. +#[expect(clippy::too_many_arguments)] fn check_fn_subtyping( infcx: &mut InferCtxt, def_id: &DefId, @@ -243,6 +242,7 @@ fn check_fn_subtyping( sub_args: &[GenericArg], super_sig: EarlyBinder, super_args: Option<(&GenericArgs, &rty::RefineArgs)>, + overflow_checking: bool, span: Span, ) -> Result { let mut infcx = infcx.branch(); @@ -268,6 +268,8 @@ fn check_fn_subtyping( .map(|ty| infcx.unpack(ty)) .collect_vec(); + let mut env = TypeEnv::empty(); + let actuals = unfold_local_ptrs(&mut infcx, &mut env, span, &sub_sig, &actuals)?; let actuals = infer_under_mut_ref_hack(&mut infcx, &actuals[..], sub_sig.as_ref()); // 2. Fresh names for `T_f` refine-params / Instantiate fn_def_sig and normalize it @@ -280,27 +282,19 @@ fn check_fn_subtyping( .with_span(span)?; // 3. INPUT subtyping (g-input <: f-input) - // TODO: Check requires predicates (?) - // for requires in fn_def_sig.requires() { - // at.check_pred(requires, ConstrReason::Call); - // } - if !sub_sig.requires().is_empty() { - span_bug!(span, "Not yet handled: requires predicates {def_id:?}"); + for requires in super_sig.requires() { + infcx.assume_pred(requires); } for (actual, formal) in iter::zip(actuals, sub_sig.inputs()) { - let (formal, pred) = formal.unconstr(); - infcx.check_pred(&pred, ConstrReason::Call); - // see: TODO(pack-closure) - match (actual.kind(), formal.kind()) { - (TyKind::Ptr(rty::PtrKind::Mut(_), _), _) => { - bug!("Not yet handled: FnDef subtyping with Ptr"); - } - _ => { - infcx - .subtyping(&actual, &formal, ConstrReason::Call) - .with_span(span)?; - } - } + let reason = ConstrReason::Subtype(SubtypeReason::Input); + infcx + .fun_arg_subtyping(&mut env, &actual, formal, reason) + .with_span(span)?; + } + // we check the requires AFTER the actual-formal subtyping as the above may unfold stuff in the actuals + for requires in sub_sig.requires() { + let reason = ConstrReason::Subtype(SubtypeReason::Requires); + infcx.check_pred(requires, reason); } // 4. Plug in the EVAR solution / replace evars @@ -314,26 +308,32 @@ fn check_fn_subtyping( // 5. OUTPUT subtyping (f_out <: g_out) // RJ: new `at` to avoid borrowing errors...! infcx.push_scope(); - let oblig_output = super_sig + let super_output = super_sig .output() .replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode)); + let reason = ConstrReason::Subtype(SubtypeReason::Output); infcx - .subtyping(&output.ret, &oblig_output.ret, ConstrReason::Ret) + .subtyping(&output.ret, &super_output.ret, reason) .with_span(span)?; + + // 6. Update state with Output "ensures" and check super ensures + env.update_ensures(&mut infcx, &output, overflow_checking); + fold_local_ptrs(&mut infcx, &mut env, span)?; + env.check_ensures(&mut infcx, &super_output, ConstrReason::Subtype(SubtypeReason::Ensures)) + .with_span(span)?; + let evars_sol = infcx.pop_scope().with_span(span)?; infcx.replace_evars(&evars_sol); - if !output.ensures.is_empty() || !oblig_output.ensures.is_empty() { - span_bug!(span, "Not yet handled: subtyping with ensures predicates {def_id:?}"); - } - Ok(()) } + /// Trait subtyping check, which makes sure that the type for an impl method (def_id) /// is a subtype of the corresponding trait method. pub fn trait_impl_subtyping( genv: GlobalEnv, def_id: LocalDefId, + overflow_checking: bool, span: Span, ) -> Result> { // Skip the check if this is not an impl method @@ -367,6 +367,7 @@ pub fn trait_impl_subtyping( &impl_args, trait_fn_sig, Some((&trait_args, &trait_refine_args)), + overflow_checking, span, )?; Ok(()) @@ -733,18 +734,8 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { .subtyping(&ret_place_ty, &output.ret, ConstrReason::Ret) .with_span(span)?; - for constraint in &output.ensures { - match constraint { - Ensures::Type(path, ty) => { - let actual_ty = env.get(path); - at.subtyping(&actual_ty, ty, ConstrReason::Ret) - .with_span(span)?; - } - Ensures::Pred(e) => { - at.check_pred(e, ConstrReason::Ret); - } - } - } + env.check_ensures(&mut at, &output, ConstrReason::Ret) + .with_span(span)?; let evars_sol = infcx.pop_scope().with_span(span)?; infcx.replace_evars(&evars_sol); @@ -831,16 +822,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { .replace_evars(&evars_sol) .replace_bound_refts_with(|sort, _, _| infcx.define_vars(sort)); - for ensures in &output.ensures { - match ensures { - Ensures::Type(path, updated_ty) => { - let updated_ty = infcx.unpack(updated_ty); - infcx.assume_invariants(&updated_ty, self.check_overflow()); - env.update_path(path, updated_ty); - } - Ensures::Pred(e) => infcx.assume_pred(e), - } - } + env.update_ensures(infcx, &output, self.check_overflow()); fold_local_ptrs(infcx, env, span)?; @@ -890,7 +872,16 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { // See `tests/neg/surface/fndef00.rs` let sub_sig = self.genv.fn_sig(def_id).with_span(span)?; let oblig_sig = fn_trait_pred.fndef_poly_sig(); - check_fn_subtyping(infcx, def_id, sub_sig, args, oblig_sig, None, span)?; + check_fn_subtyping( + infcx, + def_id, + sub_sig, + args, + oblig_sig, + None, + self.check_overflow(), + span, + )?; } _ => { // TODO: When we allow refining closure/fn at the surface level, we would need to do some function subtyping here, @@ -1369,6 +1360,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { args, super_sig, None, + self.check_overflow(), stmt_span, )?; to diff --git a/crates/flux-refineck/src/lib.rs b/crates/flux-refineck/src/lib.rs index 2759bfdc73..9573f9c488 100644 --- a/crates/flux-refineck/src/lib.rs +++ b/crates/flux-refineck/src/lib.rs @@ -37,7 +37,7 @@ use flux_common::{cache::QueryCache, dbg, result::ResultExt as _}; use flux_config as config; use flux_infer::{ fixpoint_encoding::{FixpointCtxt, KVarGen}, - infer::{ConstrReason, Tag}, + infer::{ConstrReason, SubtypeReason, Tag}, refine_tree::RefineTree, }; use flux_macros::fluent_messages; @@ -169,7 +169,8 @@ pub fn check_fn( // PHASE 4: subtyping check for trait-method implementations if let Some((refine_tree, kvars)) = - trait_impl_subtyping(genv, local_id, span).map_err(|err| err.emit_err(&genv, def_id))? + trait_impl_subtyping(genv, local_id, config.check_overflow, span) + .map_err(|err| err.emit_err(&genv, def_id))? { tracing::info!("check_fn::refine-subtyping"); let errors = @@ -206,9 +207,15 @@ fn report_errors(genv: GlobalEnv, errors: Vec) -> Result<(), ErrorGuarantee for err in errors { let span = err.src_span; e = Some(match err.reason { - ConstrReason::Call => call_error(genv, span, err.dst_span), + ConstrReason::Call + | ConstrReason::Subtype(SubtypeReason::Input) + | ConstrReason::Subtype(SubtypeReason::Requires) => { + call_error(genv, span, err.dst_span) + } ConstrReason::Assign => genv.sess().emit_err(errors::AssignError { span }), - ConstrReason::Ret => ret_error(genv, span, err.dst_span), + ConstrReason::Ret + | ConstrReason::Subtype(SubtypeReason::Output) + | ConstrReason::Subtype(SubtypeReason::Ensures) => ret_error(genv, span, err.dst_span), ConstrReason::Div => genv.sess().emit_err(errors::DivError { span }), ConstrReason::Rem => genv.sess().emit_err(errors::RemError { span }), ConstrReason::Goto(_) => genv.sess().emit_err(errors::GotoError { span }), diff --git a/crates/flux-refineck/src/type_env.rs b/crates/flux-refineck/src/type_env.rs index d7fce01d37..8babfd6b39 100644 --- a/crates/flux-refineck/src/type_env.rs +++ b/crates/flux-refineck/src/type_env.rs @@ -16,9 +16,9 @@ use flux_middle::{ evars::EVarSol, fold::{FallibleTypeFolder, TypeFoldable, TypeVisitable, TypeVisitor}, region_matching::{rty_match_regions, ty_match_regions}, - BaseTy, Binder, BoundReftKind, Expr, ExprKind, FnSig, GenericArg, HoleKind, Lambda, List, - Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty, TyKind, VariantIdx, - INNERMOST, + BaseTy, Binder, BoundReftKind, Ensures, Expr, ExprKind, FnOutput, FnSig, GenericArg, + HoleKind, Lambda, List, Loc, Mutability, Path, PtrKind, Region, SortCtor, SubsetTy, Ty, + TyKind, VariantIdx, INNERMOST, }, PlaceExt as _, }; @@ -28,6 +28,7 @@ use flux_rustc_bridge::{ ty, }; use itertools::{izip, Itertools}; +use rustc_index::IndexSlice; use rustc_middle::{mir::RETURN_PLACE, ty::TyCtxt}; use rustc_type_ir::BoundVar; @@ -86,6 +87,10 @@ impl<'a> TypeEnv<'a> { env } + pub fn empty() -> TypeEnv<'a> { + TypeEnv { bindings: PlacesTree::default(), local_decls: IndexSlice::empty() } + } + fn alloc_with_ty(&mut self, local: Local, ty: Ty) { let ty = ty_match_regions(&ty, &self.local_decls[local].ty); self.bindings.insert(local.into(), LocKind::Local, ty); @@ -329,6 +334,25 @@ impl<'a> TypeEnv<'a> { Ok(loc) } + /// ``` + /// ----------------------------------- + /// Γ ; &strg <ℓ: t> => Γ,ℓ: t ; ptr(ℓ) + /// ``` + pub(crate) fn unfold_strg_ref( + &mut self, + infcx: &mut InferCtxt, + path: &Path, + ty: &Ty, + ) -> InferResult { + if let Some(loc) = path.to_loc() { + let ty = infcx.unpack(ty); + self.bindings.insert(loc, LocKind::Universal, ty); + Ok(loc) + } else { + bug!("unfold_strg_ref: unexpected path {path:?}") + } + } + pub(crate) fn unfold( &mut self, infcx: &mut InferCtxt, @@ -357,6 +381,44 @@ impl<'a> TypeEnv<'a> { self.bindings .fmap_mut(|binding| binding.replace_evars(evars)); } + + pub(crate) fn update_ensures( + &mut self, + infcx: &mut InferCtxt, + output: &FnOutput, + overflow_checking: bool, + ) { + for ensure in &output.ensures { + match ensure { + Ensures::Type(path, updated_ty) => { + let updated_ty = infcx.unpack(updated_ty); + infcx.assume_invariants(&updated_ty, overflow_checking); + self.update_path(path, updated_ty); + } + Ensures::Pred(e) => infcx.assume_pred(e), + } + } + } + + pub(crate) fn check_ensures( + &mut self, + at: &mut InferCtxtAt, + output: &FnOutput, + reason: ConstrReason, + ) -> InferResult { + for constraint in &output.ensures { + match constraint { + Ensures::Type(path, ty) => { + let actual_ty = self.get(path); + at.subtyping(&actual_ty, ty, reason)?; + } + Ensures::Pred(e) => { + at.check_pred(e, ConstrReason::Ret); + } + } + } + Ok(()) + } } pub(crate) enum PtrToRefBound { @@ -380,6 +442,10 @@ impl flux_infer::infer::LocEnv for TypeEnv<'_> { fn get(&self, path: &Path) -> Ty { self.get(path) } + + fn unfold_strg_ref(&mut self, infcx: &mut InferCtxt, path: &Path, ty: &Ty) -> InferResult { + self.unfold_strg_ref(infcx, path, ty) + } } impl BasicBlockEnvShape { diff --git a/tests/tests/neg/surface/impl01.rs b/tests/tests/neg/surface/impl01.rs new file mode 100644 index 0000000000..304d826e47 --- /dev/null +++ b/tests/tests/neg/surface/impl01.rs @@ -0,0 +1,14 @@ +pub trait Mono { + #[flux::sig(fn (zing: &strg i32[@n]) ensures zing: i32{v:n < v})] + fn foo(z: &mut i32); +} + +pub struct Horse; + +impl Mono for Horse { + #[flux::sig(fn (z: &strg i32[@n]) ensures z: i32{v:v < n})] + fn foo(z: &mut i32) { + //~^ ERROR refinement type + *z -= 1; + } +} diff --git a/tests/tests/neg/surface/impl02.rs b/tests/tests/neg/surface/impl02.rs new file mode 100644 index 0000000000..aa9c16de57 --- /dev/null +++ b/tests/tests/neg/surface/impl02.rs @@ -0,0 +1,29 @@ +pub trait Mono { + #[flux::sig(fn (zing: &strg i32[@n]) + requires 0 < n + ensures zing: i32{v:n < v})] + fn foo(z: &mut i32); +} + +pub struct Tiger; + +impl Mono for Tiger { + #[flux::sig(fn (pig: &strg i32[@m]) + requires 100 < m + ensures pig: i32{v:m < v})] + fn foo(pig: &mut i32) { + //~^ ERROR refinement type + *pig += 1; + } +} + +pub struct Snake; + +impl Mono for Snake { + #[flux::sig(fn (pig: &strg {i32[@m] | 100 < m}) + ensures pig: i32[m+1])] + fn foo(pig: &mut i32) { + //~^ ERROR refinement type + *pig += 1; + } +} diff --git a/tests/tests/neg/surface/impl03.rs b/tests/tests/neg/surface/impl03.rs new file mode 100644 index 0000000000..00ee27c9fc --- /dev/null +++ b/tests/tests/neg/surface/impl03.rs @@ -0,0 +1,14 @@ +pub trait Mono { + #[flux::sig(fn (zing: &mut i32{v: 0 <= v}))] + fn foo(z: &mut i32); +} + +pub struct Snake; + +impl Mono for Snake { + #[flux::sig(fn (hog: &strg i32[@m]) ensures hog: i32[m-1])] + fn foo(hog: &mut i32) { + //~^ ERROR: type invariant may not hold + *hog -= 1; + } +} diff --git a/tests/tests/pos/surface/impl01.rs b/tests/tests/pos/surface/impl01.rs new file mode 100644 index 0000000000..4455558806 --- /dev/null +++ b/tests/tests/pos/surface/impl01.rs @@ -0,0 +1,22 @@ +pub trait Mono { + #[flux::sig(fn (zing: &strg i32[@n]) ensures zing: i32{v:n < v})] + fn foo(z: &mut i32); +} + +pub struct Tiger; + +impl Mono for Tiger { + #[flux::sig(fn (pig: &strg i32[@m]) ensures pig: i32{v:m < v})] + fn foo(pig: &mut i32) { + *pig += 1; + } +} + +pub struct Snake; + +impl Mono for Snake { + #[flux::sig(fn (pig: &strg i32[@m]) ensures pig: i32[m+1])] + fn foo(pig: &mut i32) { + *pig += 1; + } +} diff --git a/tests/tests/pos/surface/impl02.rs b/tests/tests/pos/surface/impl02.rs new file mode 100644 index 0000000000..38f00e4505 --- /dev/null +++ b/tests/tests/pos/surface/impl02.rs @@ -0,0 +1,27 @@ +pub trait Mono { + #[flux::sig(fn (zing: &strg i32[@n]) + requires 100 < n + ensures zing: i32{v:n < v})] + fn foo(z: &mut i32); +} + +pub struct Tiger; + +impl Mono for Tiger { + #[flux::sig(fn (pig: &strg i32[@m]) + requires 0 < m + ensures pig: i32{v:m < v})] + fn foo(pig: &mut i32) { + *pig += 1; + } +} + +pub struct Snake; + +impl Mono for Snake { + #[flux::sig(fn (pig: &strg {i32[@m] | 0 < m}) + ensures pig: i32[m+1])] + fn foo(pig: &mut i32) { + *pig += 1; + } +} diff --git a/tests/tests/pos/surface/impl03.rs b/tests/tests/pos/surface/impl03.rs new file mode 100644 index 0000000000..196d52d9d8 --- /dev/null +++ b/tests/tests/pos/surface/impl03.rs @@ -0,0 +1,13 @@ +pub trait Mono { + #[flux::sig(fn (zing: &mut i32{v: 0 <= v}))] + fn foo(z: &mut i32); +} + +pub struct Tiger; + +impl Mono for Tiger { + #[flux::sig(fn (pig: &strg i32[@m]) ensures pig: i32[m+1])] + fn foo(pig: &mut i32) { + *pig += 1; + } +}