From c6ad442f68e64ea2573a21fb5002d58adf190bf5 Mon Sep 17 00:00:00 2001 From: Nico Lehmann Date: Mon, 16 Dec 2024 17:10:18 -0300 Subject: [PATCH] Put RefinementGenerics under EarlyBinder (#948) --- crates/flux-desugar/src/desugar.rs | 2 + crates/flux-fhir-analysis/locales/en-US.ftl | 10 +- crates/flux-fhir-analysis/src/conv/mod.rs | 136 +++++++++++++----- crates/flux-fhir-analysis/src/lib.rs | 11 +- crates/flux-fhir-analysis/src/wf/mod.rs | 26 ++-- crates/flux-fhir-analysis/src/wf/sortck.rs | 13 ++ crates/flux-infer/src/infer.rs | 7 +- crates/flux-infer/src/refine_tree.rs | 54 ++++--- crates/flux-metadata/src/lib.rs | 7 +- crates/flux-middle/src/cstore.rs | 5 +- crates/flux-middle/src/fhir.rs | 1 + crates/flux-middle/src/fhir/lift.rs | 2 +- crates/flux-middle/src/global_env.rs | 2 +- crates/flux-middle/src/queries.rs | 11 +- crates/flux-middle/src/rty/binder.rs | 8 +- crates/flux-middle/src/rty/mod.rs | 106 +++++++++----- crates/flux-middle/src/rty/refining.rs | 3 +- crates/flux-refineck/src/checker.rs | 4 +- .../wf/invalid_generic_arguments.rs | 17 +++ .../neg/error_messages/wf/refinement_args.rs | 5 + 20 files changed, 292 insertions(+), 138 deletions(-) create mode 100644 tests/tests/neg/error_messages/wf/invalid_generic_arguments.rs create mode 100644 tests/tests/neg/error_messages/wf/refinement_args.rs diff --git a/crates/flux-desugar/src/desugar.rs b/crates/flux-desugar/src/desugar.rs index 9b61305681..4d0b0d3737 100644 --- a/crates/flux-desugar/src/desugar.rs +++ b/crates/flux-desugar/src/desugar.rs @@ -721,6 +721,7 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> { let res = Res::Def(def_kind, def_id); fhir::Path { span, + fhir_id: self.next_fhir_id(), res, segments: self.genv.alloc_slice_fill_iter([fhir::PathSegment { ident: surface::Ident::new(lang_item.name(), span), @@ -1145,6 +1146,7 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> { let proj_start = path.segments.len() - unresolved_segments; let fhir_path = fhir::Path { res: partial_res.base_res(), + fhir_id: self.next_fhir_id(), segments: try_alloc_slice!(self.genv(), &path.segments[..proj_start], |segment| { self.desugar_path_segment(segment) })?, diff --git a/crates/flux-fhir-analysis/locales/en-US.ftl b/crates/flux-fhir-analysis/locales/en-US.ftl index d42b238e09..1fc68da899 100644 --- a/crates/flux-fhir-analysis/locales/en-US.ftl +++ b/crates/flux-fhir-analysis/locales/en-US.ftl @@ -243,7 +243,7 @@ fhir_analysis_incorrect_generics_on_sort = *[other] {$expected} generic arguments } on sort -fhir_analysis_generics_on_type_parameter = +fhir_analysis_generics_on_sort_ty_param= type parameter expects no generics but found {$found} .label = found generics on sort type parameter @@ -258,6 +258,14 @@ fhir_analysis_generics_on_opaque_sort = fhir_analysis_refined_unrefinable_type = type cannot be refined +fhir_analysis_generics_on_prim_ty = + generic arguments are not allowed on builtin type `{$name}` + +fhir_analysis_generics_on_ty_param = + generic arguments are not allowed on type parameter `{$name}` + +fhir_analysis_generics_on_self_ty = + generic arguments are not allowed on self type # Check impl against trait errors diff --git a/crates/flux-fhir-analysis/src/conv/mod.rs b/crates/flux-fhir-analysis/src/conv/mod.rs index 20d9fe18c1..a0fa4c9ab2 100644 --- a/crates/flux-fhir-analysis/src/conv/mod.rs +++ b/crates/flux-fhir-analysis/src/conv/mod.rs @@ -30,7 +30,7 @@ use itertools::Itertools; use rustc_data_structures::fx::FxIndexMap; use rustc_errors::Diagnostic; use rustc_hash::FxHashSet; -use rustc_hir::{def::DefKind, def_id::DefId, OwnerId, PrimTy, Safety}; +use rustc_hir::{def::DefKind, def_id::DefId, OwnerId, Safety}; use rustc_middle::{ middle::resolve_bound_vars::ResolvedArg, ty::{self, AssocItem, AssocKind, BoundRegionKind::BrNamed, BoundVar, TyCtxt}, @@ -90,6 +90,10 @@ pub trait ConvPhase<'genv, 'tcx>: Sized { /// during the first phase to collect the sort of base types. fn insert_bty_sort(&mut self, fhir_id: FhirId, sort: rty::Sort); + /// Called after converting an path with the generic arguments. Using during the first phase + /// to instantiate sort of generic refinements. + fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs); + /// Called after converting an [`fhir::ExprKind::Alias`] with the sort of the resulting /// [`rty::AliasReft`]. Used during the first phase to collect the sorts of refinement aliases. fn insert_alias_reft_sort(&mut self, fhir_id: FhirId, fsort: rty::FuncSort); @@ -160,6 +164,8 @@ impl<'genv, 'tcx> ConvPhase<'genv, 'tcx> for AfterSortck<'_, 'genv, 'tcx> { fn insert_bty_sort(&mut self, _: FhirId, _: rty::Sort) {} + fn insert_path_args(&mut self, _: FhirId, _: rty::GenericArgs) {} + fn insert_alias_reft_sort(&mut self, _: FhirId, _: rty::FuncSort) {} } @@ -727,7 +733,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ fhir::SortRes::SortParam(n) => return Ok(rty::Sort::Var(rty::ParamSort::from(n))), fhir::SortRes::TyParam(def_id) => { if !path.args.is_empty() { - let err = errors::GenericsOnTyParam::new( + let err = errors::GenericsOnSortTyParam::new( path.segments.last().unwrap().span, path.args.len(), ); @@ -811,7 +817,13 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ prim_sort: fhir::PrimSort, ) -> QueryResult { if path.args.len() != prim_sort.generics() { - Err(emit_prim_sort_generics_error(self.genv(), path, prim_sort))?; + let err = errors::GenericsOnPrimitiveSort::new( + path.segments.last().unwrap().span, + prim_sort.name_str(), + path.args.len(), + prim_sort.generics(), + ); + Err(self.emit(err))?; } Ok(()) } @@ -1474,17 +1486,9 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ path: &fhir::Path, ) -> QueryResult { let bty = match path.res { - fhir::Res::PrimTy(PrimTy::Bool) => rty::BaseTy::Bool, - fhir::Res::PrimTy(PrimTy::Str) => rty::BaseTy::Str, - fhir::Res::PrimTy(PrimTy::Char) => rty::BaseTy::Char, - fhir::Res::PrimTy(PrimTy::Int(int_ty)) => { - rty::BaseTy::Int(rustc_middle::ty::int_ty(int_ty)) - } - fhir::Res::PrimTy(PrimTy::Uint(uint_ty)) => { - rty::BaseTy::Uint(rustc_middle::ty::uint_ty(uint_ty)) - } - fhir::Res::PrimTy(PrimTy::Float(float_ty)) => { - rty::BaseTy::Float(rustc_middle::ty::float_ty(float_ty)) + fhir::Res::PrimTy(prim_ty) => { + self.check_prim_ty_generics(path, prim_ty)?; + prim_ty_to_bty(prim_ty) } fhir::Res::Def(DefKind::Struct | DefKind::Enum | DefKind::Union, did) => { let adt_def = self.genv().adt_def(did)?; @@ -1494,6 +1498,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ fhir::Res::Def(DefKind::TyParam, def_id) => { let owner_id = ty_param_owner(self.genv(), def_id); let param_ty = def_id_to_param_ty(self.genv(), def_id); + self.check_ty_param_generics(path, param_ty)?; let param = self .genv() .generics_of(owner_id)? @@ -1507,6 +1512,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ } } fhir::Res::SelfTyParam { trait_ } => { + self.check_self_ty_generics(path)?; let param = &self.genv().generics_of(trait_)?.own_params[0]; match param.kind { rty::GenericParamDefKind::Type { .. } => { @@ -1517,6 +1523,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ } } fhir::Res::SelfTyAlias { alias_to, .. } => { + self.check_self_ty_generics(path)?; if P::EXPAND_TYPE_ALIASES { return Ok(self.genv().type_of(alias_to)?.instantiate_identity()); } else { @@ -1553,6 +1560,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ } fhir::Res::Def(DefKind::TyAlias, def_id) => { let args = self.conv_generic_args(env, def_id, path.last_segment())?; + self.0.insert_path_args(path.fhir_id, args.clone()); let refine_args = path .refine .iter() @@ -1567,11 +1575,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ } else { rty::BaseTy::Alias( rty::AliasKind::Weak, - rty::AliasTy { - def_id, - args: List::from(args), - refine_args: List::from(refine_args), - }, + rty::AliasTy { def_id, args, refine_args: List::from(refine_args) }, ) } } @@ -1606,10 +1610,10 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ env: &mut Env, def_id: DefId, segment: &fhir::PathSegment, - ) -> QueryResult> { + ) -> QueryResult> { let mut into = vec![]; self.conv_generic_args_into(env, def_id, segment, &mut into)?; - Ok(into) + Ok(List::from(into)) } fn conv_generic_args_into( @@ -1746,6 +1750,51 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt

{ fn emit(&self, err: impl Diagnostic<'genv>) -> ErrorGuaranteed { self.genv().sess().emit_err(err) } + + fn check_prim_ty_generics( + &mut self, + path: &fhir::Path<'_>, + prim_ty: rustc_hir::PrimTy, + ) -> QueryResult { + if !path.last_segment().args.is_empty() { + let err = errors::GenericsOnPrimTy { span: path.span, name: prim_ty.name_str() }; + Err(self.emit(err))?; + } + Ok(()) + } + + fn check_ty_param_generics( + &mut self, + path: &fhir::Path<'_>, + param_ty: rty::ParamTy, + ) -> QueryResult { + if !path.last_segment().args.is_empty() { + let err = errors::GenericsOnTyParam { span: path.span, name: param_ty.name }; + Err(self.emit(err))?; + } + Ok(()) + } + + fn check_self_ty_generics(&mut self, path: &fhir::Path<'_>) -> QueryResult { + if !path.last_segment().args.is_empty() { + let err = errors::GenericsOnSelfTy { span: path.span }; + Err(self.emit(err))?; + } + Ok(()) + } +} + +fn prim_ty_to_bty(prim_ty: rustc_hir::PrimTy) -> rty::BaseTy { + match prim_ty { + rustc_hir::PrimTy::Int(int_ty) => rty::BaseTy::Int(rustc_middle::ty::int_ty(int_ty)), + rustc_hir::PrimTy::Uint(uint_ty) => rty::BaseTy::Uint(rustc_middle::ty::uint_ty(uint_ty)), + rustc_hir::PrimTy::Float(float_ty) => { + rty::BaseTy::Float(rustc_middle::ty::float_ty(float_ty)) + } + rustc_hir::PrimTy::Str => rty::BaseTy::Str, + rustc_hir::PrimTy::Bool => rty::BaseTy::Bool, + rustc_hir::PrimTy::Char => rty::BaseTy::Char, + } } /// Conversion of expressions @@ -2177,20 +2226,6 @@ pub fn conv_func_decl(genv: GlobalEnv, func: &fhir::SpecFunc) -> QueryResult ErrorGuaranteed { - let err = errors::GenericsOnPrimitiveSort::new( - path.segments.last().unwrap().span, - prim_sort.name_str(), - path.args.len(), - prim_sort.generics(), - ); - genv.sess().emit_err(err) -} - fn conv_lit(lit: fhir::Lit) -> rty::Constant { match lit { fhir::Lit::Int(n) => rty::Constant::from(n), @@ -2242,7 +2277,7 @@ mod errors { use flux_macros::Diagnostic; use flux_middle::{fhir, global_env::GlobalEnv}; use rustc_hir::def_id::DefId; - use rustc_span::{symbol::Ident, Span}; + use rustc_span::{symbol::Ident, Span, Symbol}; #[derive(Diagnostic)] #[diag(fhir_analysis_assoc_type_not_found, code = E0999)] @@ -2421,15 +2456,15 @@ mod errors { } #[derive(Diagnostic)] - #[diag(fhir_analysis_generics_on_type_parameter, code = E0999)] - pub(super) struct GenericsOnTyParam { + #[diag(fhir_analysis_generics_on_sort_ty_param, code = E0999)] + pub(super) struct GenericsOnSortTyParam { #[primary_span] #[label] span: Span, found: usize, } - impl GenericsOnTyParam { + impl GenericsOnSortTyParam { pub(super) fn new(span: Span, found: usize) -> Self { Self { span, found } } @@ -2464,4 +2499,27 @@ mod errors { Self { span, found } } } + + #[derive(Diagnostic)] + #[diag(fhir_analysis_generics_on_prim_ty, code = E0999)] + pub(super) struct GenericsOnPrimTy { + #[primary_span] + pub span: Span, + pub name: &'static str, + } + + #[derive(Diagnostic)] + #[diag(fhir_analysis_generics_on_ty_param, code = E0999)] + pub(super) struct GenericsOnTyParam { + #[primary_span] + pub span: Span, + pub name: Symbol, + } + + #[derive(Diagnostic)] + #[diag(fhir_analysis_generics_on_self_ty, code = E0999)] + pub(super) struct GenericsOnSelfTy { + #[primary_span] + pub span: Span, + } } diff --git a/crates/flux-fhir-analysis/src/lib.rs b/crates/flux-fhir-analysis/src/lib.rs index f69ce3d786..43ee861590 100644 --- a/crates/flux-fhir-analysis/src/lib.rs +++ b/crates/flux-fhir-analysis/src/lib.rs @@ -400,11 +400,11 @@ fn generics_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult QueryResult { +) -> QueryResult> { let parent = genv.tcx().generics_of(local_id).parent; let parent_count = if let Some(def_id) = parent { genv.refinement_generics_of(def_id)?.count() } else { 0 }; - match genv.map().node(local_id)? { + let generics = match genv.map().node(local_id)? { fhir::Node::Item(fhir::Item { kind: fhir::ItemKind::Fn(..) | fhir::ItemKind::TyAlias(..), generics, @@ -420,10 +420,11 @@ fn refinement_generics_of( }) => { let wfckresults = genv.check_wf(local_id)?; let params = conv::conv_refinement_generics(generics.refinement_params, &wfckresults)?; - Ok(rty::RefinementGenerics { parent, parent_count, own_params: params }) + rty::RefinementGenerics { parent, parent_count, own_params: params } } - _ => Ok(rty::RefinementGenerics { parent, parent_count, own_params: rty::List::empty() }), - } + _ => rty::RefinementGenerics { parent, parent_count, own_params: rty::List::empty() }, + }; + Ok(rty::EarlyBinder(generics)) } fn type_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult> { diff --git a/crates/flux-fhir-analysis/src/wf/mod.rs b/crates/flux-fhir-analysis/src/wf/mod.rs index 8fe71bf1f8..7f0caea7ee 100644 --- a/crates/flux-fhir-analysis/src/wf/mod.rs +++ b/crates/flux-fhir-analysis/src/wf/mod.rs @@ -6,8 +6,6 @@ mod errors; mod param_usage; mod sortck; -use std::iter; - use flux_common::result::{ErrorCollector, ResultExt as _}; use flux_errors::{Errors, FluxSession}; use flux_middle::{ @@ -377,31 +375,29 @@ impl<'genv> fhir::visit::Visitor<'genv> for Wf<'_, 'genv, '_> { } fn visit_path(&mut self, path: &fhir::Path<'genv>) { + let genv = self.genv(); if let fhir::Res::Def(DefKind::TyAlias, def_id) = path.res { - let Some(generics) = self - .infcx - .genv - .refinement_generics_of(def_id) - .emit(&self.errors) - .ok() - else { + let Ok(generics) = genv.refinement_generics_of(def_id).emit(&self.errors) else { return; }; - if path.refine.len() != generics.own_params.len() { + if path.refine.len() != generics.count() { self.errors.emit(errors::EarlyBoundArgCountMismatch::new( path.span, - generics.own_params.len(), + generics.count(), path.refine.len(), )); } - for (expr, param) in iter::zip(path.refine, &generics.own_params) { + let args = self.infcx.path_args(path.fhir_id); + for (i, expr) in path.refine.iter().enumerate() { + let Ok(param) = generics.param_at(i, genv).emit(&self.errors) else { return }; + let param = param.instantiate(genv.tcx(), &args, &[]); self.infcx .check_expr(expr, ¶m.sort) .collect_err(&mut self.errors); } - } + }; fhir::visit::walk_path(self, path); } } @@ -471,6 +467,10 @@ impl<'genv, 'tcx> ConvPhase<'genv, 'tcx> for Wf<'_, 'genv, 'tcx> { self.infcx.insert_sort_for_bty(fhir_id, sort); } + fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) { + self.infcx.insert_path_args(fhir_id, args); + } + fn insert_alias_reft_sort(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) { self.infcx.insert_sort_for_alias_reft(fhir_id, fsort); } diff --git a/crates/flux-fhir-analysis/src/wf/sortck.rs b/crates/flux-fhir-analysis/src/wf/sortck.rs index b9a79409c3..287325a93c 100644 --- a/crates/flux-fhir-analysis/src/wf/sortck.rs +++ b/crates/flux-fhir-analysis/src/wf/sortck.rs @@ -31,6 +31,7 @@ pub(super) struct InferCtxt<'genv, 'tcx> { num_unification_table: InPlaceUnificationTable, bv_size_unification_table: InPlaceUnificationTable, sort_of_bty: FxHashMap, + path_args: UnordMap, sort_of_alias_reft: UnordMap, } @@ -47,6 +48,7 @@ impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> { num_unification_table: InPlaceUnificationTable::new(), bv_size_unification_table: InPlaceUnificationTable::new(), sort_of_bty: Default::default(), + path_args: Default::default(), sort_of_alias_reft: Default::default(), } } @@ -452,6 +454,17 @@ impl<'genv, 'tcx> InferCtxt<'genv, 'tcx> { .clone() } + pub(crate) fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) { + self.path_args.insert(fhir_id, args); + } + + pub(crate) fn path_args(&self, fhir_id: FhirId) -> rty::GenericArgs { + self.path_args + .get(&fhir_id) + .unwrap_or_else(|| tracked_span_bug!("no entry found for `{fhir_id:?}`")) + .clone() + } + pub(crate) fn insert_sort_for_alias_reft(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) { self.sort_of_alias_reft.insert(fhir_id, fsort); } diff --git a/crates/flux-infer/src/infer.rs b/crates/flux-infer/src/infer.rs index 6f754254e1..fca6fb85ed 100644 --- a/crates/flux-infer/src/infer.rs +++ b/crates/flux-infer/src/infer.rs @@ -164,8 +164,13 @@ impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> { InferCtxtAt { infcx: self, span } } - pub fn instantiate_refine_args(&mut self, callee_def_id: DefId) -> InferResult> { + pub fn instantiate_refine_args( + &mut self, + callee_def_id: DefId, + args: &[rty::GenericArg], + ) -> InferResult> { Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| { + let param = param.instantiate(self.genv.tcx(), args, &[]); self.fresh_infer_var(¶m.sort, param.mode) })?) } diff --git a/crates/flux-infer/src/refine_tree.rs b/crates/flux-infer/src/refine_tree.rs index 55950d15ee..4b9175e608 100644 --- a/crates/flux-infer/src/refine_tree.rs +++ b/crates/flux-infer/src/refine_tree.rs @@ -14,7 +14,7 @@ use flux_middle::{ canonicalize::{Hoister, HoisterDelegate}, evars::EVarSol, fold::{TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor}, - BaseTy, EarlyBinder, EarlyReftParam, Expr, GenericArgs, Name, Sort, SpecFuncDefns, Ty, + BaseTy, EarlyReftParam, Expr, GenericArgs, Name, RefineParam, Sort, SpecFuncDefns, Ty, TyCtor, TyKind, Var, }, }; @@ -201,33 +201,28 @@ impl RefineTree { args: Option<&GenericArgs>, ) -> QueryResult { let generics = genv.generics_of(def_id)?; - let reft_generics = genv.refinement_generics_of(def_id)?; - - let params: Vec<(Var, Sort)> = itertools::chain( - generics - .const_params(genv)? - .into_iter() - .map(|(pcst, sort)| Ok((Var::ConstGeneric(pcst), sort))), - (0..reft_generics.count()).map(|i| { - let param = reft_generics.param_at(i, genv)?; - let var = Var::EarlyParam(EarlyReftParam { index: i as u32, name: param.name }); - Ok((var, param.sort)) - }), - ) - .collect::>()?; - - // We have `generic_args` when want to instantiate a generic trait method at a particular - // impl's type, e.g. when doing impl-trait subtyping. - let params = if let Some(generic_args) = args { - params - .iter() - .map(|(var, sort)| { - (*var, EarlyBinder(sort.clone()).instantiate(genv.tcx(), generic_args, &[])) - }) - .collect() - } else { - params - }; + + let mut params = generics + .const_params(genv)? + .into_iter() + .map(|(pcst, sort)| (Var::ConstGeneric(pcst), sort)) + .collect_vec(); + let offset = params.len(); + genv.refinement_generics_of(def_id)?.fill_item( + genv, + &mut params, + &mut |param, index| { + let index = (index - offset) as u32; + let param: RefineParam = if let Some(args) = args { + param.instantiate(genv.tcx(), args, &[]) + } else { + param.instantiate_identity() + }; + let var = Var::EarlyParam(EarlyReftParam { index, name: param.name }); + (var, param.sort) + }, + )?; + let root = Node { kind: NodeKind::Root(params), nbindings: 0, parent: None, children: vec![] }; let root = NodePtr(Rc::new(RefCell::new(root))); @@ -721,7 +716,8 @@ mod pretty { let n = node.borrow(); match &n.kind { NodeKind::Root(bindings) => { - for (name, sort) in bindings { + // We reverse here because is reversed again at the end + for (name, sort) in bindings.iter().rev() { elements.push(format_cx!(cx, "{:?} {:?}", ^name, sort)); } } diff --git a/crates/flux-metadata/src/lib.rs b/crates/flux-metadata/src/lib.rs index 10777f58b2..46efa11eea 100644 --- a/crates/flux-metadata/src/lib.rs +++ b/crates/flux-metadata/src/lib.rs @@ -103,7 +103,7 @@ impl Key for (DefId, Symbol) { #[derive(TyEncodable, TyDecodable)] pub struct Tables { generics_of: UnordMap>, - refinement_generics_of: UnordMap>, + refinement_generics_of: UnordMap>>, predicates_of: UnordMap>>, item_bounds: UnordMap>>, assoc_refinements_of: UnordMap>, @@ -208,7 +208,10 @@ impl CrateStore for CStore { get!(self, generics_of, def_id) } - fn refinement_generics_of(&self, def_id: DefId) -> OptResult { + fn refinement_generics_of( + &self, + def_id: DefId, + ) -> OptResult> { get!(self, refinement_generics_of, def_id) } diff --git a/crates/flux-middle/src/cstore.rs b/crates/flux-middle/src/cstore.rs index 705d0a1cbe..6627193001 100644 --- a/crates/flux-middle/src/cstore.rs +++ b/crates/flux-middle/src/cstore.rs @@ -9,7 +9,10 @@ pub trait CrateStore { fn adt_def(&self, def_id: DefId) -> OptResult; fn adt_sort_def(&self, def_id: DefId) -> OptResult; fn generics_of(&self, def_id: DefId) -> OptResult; - fn refinement_generics_of(&self, def_id: DefId) -> OptResult; + fn refinement_generics_of( + &self, + def_id: DefId, + ) -> OptResult>; fn item_bounds(&self, def_id: DefId) -> OptResult>; fn predicates_of(&self, def_id: DefId) -> OptResult>; fn assoc_refinements_of(&self, def_id: DefId) -> OptResult; diff --git a/crates/flux-middle/src/fhir.rs b/crates/flux-middle/src/fhir.rs index 8aeeef86a5..a5519fcabd 100644 --- a/crates/flux-middle/src/fhir.rs +++ b/crates/flux-middle/src/fhir.rs @@ -661,6 +661,7 @@ pub enum QPath<'fhir> { #[derive(Clone, Copy)] pub struct Path<'fhir> { pub res: Res, + pub fhir_id: FhirId, pub segments: &'fhir [PathSegment<'fhir>], pub refine: &'fhir [Expr<'fhir>], pub span: Span, diff --git a/crates/flux-middle/src/fhir/lift.rs b/crates/flux-middle/src/fhir/lift.rs index 50eeeb7df3..18b346bc9b 100644 --- a/crates/flux-middle/src/fhir/lift.rs +++ b/crates/flux-middle/src/fhir/lift.rs @@ -397,7 +397,7 @@ impl<'a, 'genv, 'tcx> LiftCtxt<'a, 'genv, 'tcx> { let segments = try_alloc_slice!(self.genv, path.segments, |segment| self.lift_path_segment(segment))?; - Ok(fhir::Path { res, segments, refine: &[], span: path.span }) + Ok(fhir::Path { res, fhir_id: self.next_fhir_id(), segments, refine: &[], span: path.span }) } fn lift_path_segment( diff --git a/crates/flux-middle/src/global_env.rs b/crates/flux-middle/src/global_env.rs index 55723456fe..fcda5e7133 100644 --- a/crates/flux-middle/src/global_env.rs +++ b/crates/flux-middle/src/global_env.rs @@ -246,7 +246,7 @@ impl<'genv, 'tcx> GlobalEnv<'genv, 'tcx> { pub fn refinement_generics_of( self, def_id: impl IntoQueryParam, - ) -> QueryResult { + ) -> QueryResult> { self.inner .queries .refinement_generics_of(self, def_id.into_query_param()) diff --git a/crates/flux-middle/src/queries.rs b/crates/flux-middle/src/queries.rs index 107203009c..27d9587407 100644 --- a/crates/flux-middle/src/queries.rs +++ b/crates/flux-middle/src/queries.rs @@ -142,7 +142,8 @@ pub struct Providers { ) -> QueryResult>>, pub fn_sig: fn(GlobalEnv, LocalDefId) -> QueryResult>, pub generics_of: fn(GlobalEnv, LocalDefId) -> QueryResult, - pub refinement_generics_of: fn(GlobalEnv, LocalDefId) -> QueryResult, + pub refinement_generics_of: + fn(GlobalEnv, LocalDefId) -> QueryResult>, pub predicates_of: fn(GlobalEnv, LocalDefId) -> QueryResult>, pub assoc_refinements_of: fn(GlobalEnv, LocalDefId) -> QueryResult, @@ -209,7 +210,7 @@ pub struct Queries<'genv, 'tcx> { adt_def: Cache>, constant_info: Cache>, generics_of: Cache>, - refinement_generics_of: Cache>, + refinement_generics_of: Cache>>, predicates_of: Cache>>, assoc_refinements_of: Cache>, assoc_refinement_def: Cache<(DefId, Symbol), QueryResult>>, @@ -489,7 +490,7 @@ impl<'genv, 'tcx> Queries<'genv, 'tcx> { &self, genv: GlobalEnv, def_id: DefId, - ) -> QueryResult { + ) -> QueryResult> { run_with_cache(&self.refinement_generics_of, def_id, || { dispatch_query( genv, @@ -498,11 +499,11 @@ impl<'genv, 'tcx> Queries<'genv, 'tcx> { |def_id| genv.cstore().refinement_generics_of(def_id), |def_id| { let parent = genv.tcx().generics_of(def_id).parent; - Ok(rty::RefinementGenerics { + Ok(rty::EarlyBinder(rty::RefinementGenerics { parent, parent_count: 0, own_params: List::empty(), - }) + })) }, ) }) diff --git a/crates/flux-middle/src/rty/binder.rs b/crates/flux-middle/src/rty/binder.rs index 4e873d7f31..8e2b0ba351 100644 --- a/crates/flux-middle/src/rty/binder.rs +++ b/crates/flux-middle/src/rty/binder.rs @@ -16,7 +16,7 @@ use rustc_span::Symbol; use super::{ fold::TypeFoldable, subst::{self, BoundVarReplacer, FnMutDelegate}, - Expr, GenericArg, InferMode, Sort, + Expr, GenericArg, InferMode, RefineParam, Sort, }; #[derive(Clone, Debug, TyEncodable, TyDecodable)] @@ -66,6 +66,12 @@ impl EarlyBinder { } } +impl EarlyBinder { + pub fn name(&self) -> Symbol { + self.skip_binder_ref().name + } +} + #[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)] pub struct Binder { vars: List, diff --git a/crates/flux-middle/src/rty/mod.rs b/crates/flux-middle/src/rty/mod.rs index aa9756f913..4754f7b1f0 100644 --- a/crates/flux-middle/src/rty/mod.rs +++ b/crates/flux-middle/src/rty/mod.rs @@ -220,7 +220,9 @@ pub struct RefinementGenerics { pub own_params: List, } -#[derive(PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable)] +#[derive( + PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable, +)] pub struct RefineParam { pub sort: Sort, pub name: Symbol, @@ -1211,7 +1213,7 @@ impl Ty { let def_id = genv.tcx().require_lang_item(LangItem::OwnedBox, None); let adt_def = genv.adt_def(def_id)?; - let args = vec![GenericArg::Ty(deref_ty), GenericArg::Ty(alloc_ty)]; + let args = List::from_arr([GenericArg::Ty(deref_ty), GenericArg::Ty(alloc_ty)]); let bty = BaseTy::adt(adt_def, args); Ok(Ty::indexed(bty, Expr::unit_adt(def_id))) @@ -1457,8 +1459,8 @@ impl BaseTy { BaseTy::Alias(AliasKind::Projection, alias_ty) } - pub fn adt(adt_def: AdtDef, args: impl Into) -> BaseTy { - BaseTy::Adt(adt_def, args.into()) + pub fn adt(adt_def: AdtDef, args: GenericArgs) -> BaseTy { + BaseTy::Adt(adt_def, args) } pub fn fn_def(def_id: DefId, args: impl Into) -> BaseTy { @@ -1757,41 +1759,21 @@ pub type RefineArgs = List; #[extension(pub trait RefineArgsExt)] impl RefineArgs { fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult { - Self::for_item(genv, def_id, |param, exprs| { - let index = exprs.len() as u32; - Expr::var(Var::EarlyParam(EarlyReftParam { index, name: param.name })) + Self::for_item(genv, def_id, |param, index| { + Expr::var(Var::EarlyParam(EarlyReftParam { index: index as u32, name: param.name() })) }) } fn for_item(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult where - F: FnMut(&RefineParam, &[Expr]) -> Expr, + F: FnMut(EarlyBinder, usize) -> Expr, { let reft_generics = genv.refinement_generics_of(def_id)?; let count = reft_generics.count(); let mut args = Vec::with_capacity(count); - Self::fill_item(genv, &mut args, &reft_generics, &mut mk)?; + reft_generics.fill_item(genv, &mut args, &mut mk)?; Ok(List::from_vec(args)) } - - fn fill_item( - genv: GlobalEnv, - args: &mut Vec, - reft_generics: &RefinementGenerics, - mk: &mut F, - ) -> QueryResult<()> - where - F: FnMut(&RefineParam, &[Expr]) -> Expr, - { - if let Some(def_id) = reft_generics.parent { - let parent_generics = genv.refinement_generics_of(def_id)?; - Self::fill_item(genv, args, &parent_generics, mk)?; - } - for param in &reft_generics.own_params { - args.push(mk(param, args)); - } - Ok(()) - } } /// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias @@ -2163,14 +2145,8 @@ impl RefinementGenerics { self.parent_count + self.own_params.len() } - pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult { - if let Some(index) = param_index.checked_sub(self.parent_count) { - Ok(self.own_params[index].clone()) - } else { - let parent = self.parent.expect("parent_count > 0 but no parent?"); - genv.refinement_generics_of(parent)? - .param_at(param_index, genv) - } + pub fn own_count(&self) -> usize { + self.own_params.len() } // /// Iterate and collect all parameters in this item including parents @@ -2188,6 +2164,64 @@ impl RefinementGenerics { // } } +impl EarlyBinder { + pub fn parent(&self) -> Option { + self.skip_binder_ref().parent + } + + pub fn parent_count(&self) -> usize { + self.skip_binder_ref().parent_count + } + + pub fn count(&self) -> usize { + self.skip_binder_ref().count() + } + + pub fn own_count(&self) -> usize { + self.skip_binder_ref().own_count() + } + + pub fn own_param_at(&self, index: usize) -> EarlyBinder { + self.as_ref().map(|this| this.own_params[index].clone()) + } + + pub fn param_at( + &self, + param_index: usize, + genv: GlobalEnv, + ) -> QueryResult> { + if let Some(index) = param_index.checked_sub(self.parent_count()) { + Ok(self.own_param_at(index)) + } else { + let parent = self.parent().expect("parent_count > 0 but no parent?"); + genv.refinement_generics_of(parent)? + .param_at(param_index, genv) + } + } + + pub fn iter_own_params(&self) -> impl Iterator> + use<'_> { + self.skip_binder_ref() + .own_params + .iter() + .cloned() + .map(EarlyBinder) + } + + pub fn fill_item(&self, genv: GlobalEnv, vec: &mut Vec, mk: &mut F) -> QueryResult + where + F: FnMut(EarlyBinder, usize) -> R, + { + if let Some(def_id) = self.parent() { + genv.refinement_generics_of(def_id)? + .fill_item(genv, vec, mk)?; + } + for param in self.iter_own_params() { + vec.push(mk(param, vec.len())); + } + Ok(()) + } +} + impl EarlyBinder { pub fn predicates(&self) -> EarlyBinder> { EarlyBinder(self.0.predicates.clone()) diff --git a/crates/flux-middle/src/rty/refining.rs b/crates/flux-middle/src/rty/refining.rs index f278137eca..9504169665 100644 --- a/crates/flux-middle/src/rty/refining.rs +++ b/crates/flux-middle/src/rty/refining.rs @@ -320,7 +320,8 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { let refine_args = if let ty::AliasKind::Opaque = alias_kind { rty::RefineArgs::for_item(self.genv, def_id, |param, _| { - rty::Expr::hole(rty::HoleKind::Expr(param.sort.clone())) + let param = param.instantiate(self.genv.tcx(), &args, &[]); + rty::Expr::hole(rty::HoleKind::Expr(param.sort)) })? } else { List::empty() diff --git a/crates/flux-refineck/src/checker.rs b/crates/flux-refineck/src/checker.rs index 71e53fa713..eb46024128 100644 --- a/crates/flux-refineck/src/checker.rs +++ b/crates/flux-refineck/src/checker.rs @@ -263,7 +263,7 @@ fn check_fn_subtyping( // 2. Fresh names for `T_f` refine-params / Instantiate fn_def_sig and normalize it infcx.push_scope(); - let refine_args = infcx.instantiate_refine_args(*def_id)?; + let refine_args = infcx.instantiate_refine_args(*def_id, sub_args)?; let sub_sig = sub_sig.instantiate(tcx, sub_args, &refine_args); let sub_sig = sub_sig .replace_bound_vars(|_| rty::ReErased, |sort, mode| infcx.fresh_infer_var(sort, mode)) @@ -777,7 +777,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { let refine_args = match callee_def_id { Some(callee_def_id) => { infcx - .instantiate_refine_args(callee_def_id) + .instantiate_refine_args(callee_def_id, &generic_args) .with_span(span)? } None => rty::List::empty(), diff --git a/tests/tests/neg/error_messages/wf/invalid_generic_arguments.rs b/tests/tests/neg/error_messages/wf/invalid_generic_arguments.rs new file mode 100644 index 0000000000..2c807f2276 --- /dev/null +++ b/tests/tests/neg/error_messages/wf/invalid_generic_arguments.rs @@ -0,0 +1,17 @@ +#[flux::sig(fn(x: i32))] //~ ERROR generic arguments are not allowed on builtin type +fn test00() {} + +#[flux::sig(fn(x: T))] //~ ERROR generic arguments are not allowed on type parameter +fn test01() {} + +struct S; + +impl S { + #[flux::sig(fn(Self))] //~ ERROR generic arguments are not allowed on self type + fn test02() {} +} + +trait MyTrait { + #[flux::sig(fn(Self))] //~ ERROR generic arguments are not allowed on self type + fn test03(); +} diff --git a/tests/tests/neg/error_messages/wf/refinement_args.rs b/tests/tests/neg/error_messages/wf/refinement_args.rs new file mode 100644 index 0000000000..8d9edae942 --- /dev/null +++ b/tests/tests/neg/error_messages/wf/refinement_args.rs @@ -0,0 +1,5 @@ +#[flux::alias(type A(n: int) = i32{v: v >= n})] +type A = i32; + +#[flux::sig(fn(A(false)))] //~ ERROR mismatched sorts +fn foo(x: A) {}