Skip to content

Commit

Permalink
Put RefinementGenerics under EarlyBinder (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann authored Dec 16, 2024
1 parent f440fa9 commit c6ad442
Show file tree
Hide file tree
Showing 20 changed files with 292 additions and 138 deletions.
2 changes: 2 additions & 0 deletions crates/flux-desugar/src/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
let res = Res::Def(def_kind, def_id);
fhir::Path {
span,
fhir_id: self.next_fhir_id(),
res,
segments: self.genv.alloc_slice_fill_iter([fhir::PathSegment {
ident: surface::Ident::new(lang_item.name(), span),
Expand Down Expand Up @@ -1145,6 +1146,7 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
let proj_start = path.segments.len() - unresolved_segments;
let fhir_path = fhir::Path {
res: partial_res.base_res(),
fhir_id: self.next_fhir_id(),
segments: try_alloc_slice!(self.genv(), &path.segments[..proj_start], |segment| {
self.desugar_path_segment(segment)
})?,
Expand Down
10 changes: 9 additions & 1 deletion crates/flux-fhir-analysis/locales/en-US.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ fhir_analysis_incorrect_generics_on_sort =
*[other] {$expected} generic arguments
} on sort
fhir_analysis_generics_on_type_parameter =
fhir_analysis_generics_on_sort_ty_param=
type parameter expects no generics but found {$found}
.label = found generics on sort type parameter
Expand All @@ -258,6 +258,14 @@ fhir_analysis_generics_on_opaque_sort =
fhir_analysis_refined_unrefinable_type =
type cannot be refined
fhir_analysis_generics_on_prim_ty =
generic arguments are not allowed on builtin type `{$name}`
fhir_analysis_generics_on_ty_param =
generic arguments are not allowed on type parameter `{$name}`
fhir_analysis_generics_on_self_ty =
generic arguments are not allowed on self type
# Check impl against trait errors

Expand Down
136 changes: 97 additions & 39 deletions crates/flux-fhir-analysis/src/conv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use itertools::Itertools;
use rustc_data_structures::fx::FxIndexMap;
use rustc_errors::Diagnostic;
use rustc_hash::FxHashSet;
use rustc_hir::{def::DefKind, def_id::DefId, OwnerId, PrimTy, Safety};
use rustc_hir::{def::DefKind, def_id::DefId, OwnerId, Safety};
use rustc_middle::{
middle::resolve_bound_vars::ResolvedArg,
ty::{self, AssocItem, AssocKind, BoundRegionKind::BrNamed, BoundVar, TyCtxt},
Expand Down Expand Up @@ -90,6 +90,10 @@ pub trait ConvPhase<'genv, 'tcx>: Sized {
/// during the first phase to collect the sort of base types.
fn insert_bty_sort(&mut self, fhir_id: FhirId, sort: rty::Sort);

/// Called after converting an path with the generic arguments. Using during the first phase
/// to instantiate sort of generic refinements.
fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs);

/// Called after converting an [`fhir::ExprKind::Alias`] with the sort of the resulting
/// [`rty::AliasReft`]. Used during the first phase to collect the sorts of refinement aliases.
fn insert_alias_reft_sort(&mut self, fhir_id: FhirId, fsort: rty::FuncSort);
Expand Down Expand Up @@ -160,6 +164,8 @@ impl<'genv, 'tcx> ConvPhase<'genv, 'tcx> for AfterSortck<'_, 'genv, 'tcx> {

fn insert_bty_sort(&mut self, _: FhirId, _: rty::Sort) {}

fn insert_path_args(&mut self, _: FhirId, _: rty::GenericArgs) {}

fn insert_alias_reft_sort(&mut self, _: FhirId, _: rty::FuncSort) {}
}

Expand Down Expand Up @@ -727,7 +733,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
fhir::SortRes::SortParam(n) => return Ok(rty::Sort::Var(rty::ParamSort::from(n))),
fhir::SortRes::TyParam(def_id) => {
if !path.args.is_empty() {
let err = errors::GenericsOnTyParam::new(
let err = errors::GenericsOnSortTyParam::new(
path.segments.last().unwrap().span,
path.args.len(),
);
Expand Down Expand Up @@ -811,7 +817,13 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
prim_sort: fhir::PrimSort,
) -> QueryResult {
if path.args.len() != prim_sort.generics() {
Err(emit_prim_sort_generics_error(self.genv(), path, prim_sort))?;
let err = errors::GenericsOnPrimitiveSort::new(
path.segments.last().unwrap().span,
prim_sort.name_str(),
path.args.len(),
prim_sort.generics(),
);
Err(self.emit(err))?;
}
Ok(())
}
Expand Down Expand Up @@ -1474,17 +1486,9 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
path: &fhir::Path,
) -> QueryResult<rty::TyOrCtor> {
let bty = match path.res {
fhir::Res::PrimTy(PrimTy::Bool) => rty::BaseTy::Bool,
fhir::Res::PrimTy(PrimTy::Str) => rty::BaseTy::Str,
fhir::Res::PrimTy(PrimTy::Char) => rty::BaseTy::Char,
fhir::Res::PrimTy(PrimTy::Int(int_ty)) => {
rty::BaseTy::Int(rustc_middle::ty::int_ty(int_ty))
}
fhir::Res::PrimTy(PrimTy::Uint(uint_ty)) => {
rty::BaseTy::Uint(rustc_middle::ty::uint_ty(uint_ty))
}
fhir::Res::PrimTy(PrimTy::Float(float_ty)) => {
rty::BaseTy::Float(rustc_middle::ty::float_ty(float_ty))
fhir::Res::PrimTy(prim_ty) => {
self.check_prim_ty_generics(path, prim_ty)?;
prim_ty_to_bty(prim_ty)
}
fhir::Res::Def(DefKind::Struct | DefKind::Enum | DefKind::Union, did) => {
let adt_def = self.genv().adt_def(did)?;
Expand All @@ -1494,6 +1498,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
fhir::Res::Def(DefKind::TyParam, def_id) => {
let owner_id = ty_param_owner(self.genv(), def_id);
let param_ty = def_id_to_param_ty(self.genv(), def_id);
self.check_ty_param_generics(path, param_ty)?;
let param = self
.genv()
.generics_of(owner_id)?
Expand All @@ -1507,6 +1512,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
}
}
fhir::Res::SelfTyParam { trait_ } => {
self.check_self_ty_generics(path)?;
let param = &self.genv().generics_of(trait_)?.own_params[0];
match param.kind {
rty::GenericParamDefKind::Type { .. } => {
Expand All @@ -1517,6 +1523,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
}
}
fhir::Res::SelfTyAlias { alias_to, .. } => {
self.check_self_ty_generics(path)?;
if P::EXPAND_TYPE_ALIASES {
return Ok(self.genv().type_of(alias_to)?.instantiate_identity());
} else {
Expand Down Expand Up @@ -1553,6 +1560,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
}
fhir::Res::Def(DefKind::TyAlias, def_id) => {
let args = self.conv_generic_args(env, def_id, path.last_segment())?;
self.0.insert_path_args(path.fhir_id, args.clone());
let refine_args = path
.refine
.iter()
Expand All @@ -1567,11 +1575,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
} else {
rty::BaseTy::Alias(
rty::AliasKind::Weak,
rty::AliasTy {
def_id,
args: List::from(args),
refine_args: List::from(refine_args),
},
rty::AliasTy { def_id, args, refine_args: List::from(refine_args) },
)
}
}
Expand Down Expand Up @@ -1606,10 +1610,10 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
env: &mut Env,
def_id: DefId,
segment: &fhir::PathSegment,
) -> QueryResult<Vec<rty::GenericArg>> {
) -> QueryResult<List<rty::GenericArg>> {
let mut into = vec![];
self.conv_generic_args_into(env, def_id, segment, &mut into)?;
Ok(into)
Ok(List::from(into))
}

fn conv_generic_args_into(
Expand Down Expand Up @@ -1746,6 +1750,51 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
fn emit(&self, err: impl Diagnostic<'genv>) -> ErrorGuaranteed {
self.genv().sess().emit_err(err)
}

fn check_prim_ty_generics(
&mut self,
path: &fhir::Path<'_>,
prim_ty: rustc_hir::PrimTy,
) -> QueryResult {
if !path.last_segment().args.is_empty() {
let err = errors::GenericsOnPrimTy { span: path.span, name: prim_ty.name_str() };
Err(self.emit(err))?;
}
Ok(())
}

fn check_ty_param_generics(
&mut self,
path: &fhir::Path<'_>,
param_ty: rty::ParamTy,
) -> QueryResult {
if !path.last_segment().args.is_empty() {
let err = errors::GenericsOnTyParam { span: path.span, name: param_ty.name };
Err(self.emit(err))?;
}
Ok(())
}

fn check_self_ty_generics(&mut self, path: &fhir::Path<'_>) -> QueryResult {
if !path.last_segment().args.is_empty() {
let err = errors::GenericsOnSelfTy { span: path.span };
Err(self.emit(err))?;
}
Ok(())
}
}

fn prim_ty_to_bty(prim_ty: rustc_hir::PrimTy) -> rty::BaseTy {
match prim_ty {
rustc_hir::PrimTy::Int(int_ty) => rty::BaseTy::Int(rustc_middle::ty::int_ty(int_ty)),
rustc_hir::PrimTy::Uint(uint_ty) => rty::BaseTy::Uint(rustc_middle::ty::uint_ty(uint_ty)),
rustc_hir::PrimTy::Float(float_ty) => {
rty::BaseTy::Float(rustc_middle::ty::float_ty(float_ty))
}
rustc_hir::PrimTy::Str => rty::BaseTy::Str,
rustc_hir::PrimTy::Bool => rty::BaseTy::Bool,
rustc_hir::PrimTy::Char => rty::BaseTy::Char,
}
}

/// Conversion of expressions
Expand Down Expand Up @@ -2177,20 +2226,6 @@ pub fn conv_func_decl(genv: GlobalEnv, func: &fhir::SpecFunc) -> QueryResult<rty
Ok(rty::SpecFuncDecl { name: func.name, sort, kind })
}

fn emit_prim_sort_generics_error(
genv: GlobalEnv,
path: &fhir::SortPath,
prim_sort: fhir::PrimSort,
) -> ErrorGuaranteed {
let err = errors::GenericsOnPrimitiveSort::new(
path.segments.last().unwrap().span,
prim_sort.name_str(),
path.args.len(),
prim_sort.generics(),
);
genv.sess().emit_err(err)
}

fn conv_lit(lit: fhir::Lit) -> rty::Constant {
match lit {
fhir::Lit::Int(n) => rty::Constant::from(n),
Expand Down Expand Up @@ -2242,7 +2277,7 @@ mod errors {
use flux_macros::Diagnostic;
use flux_middle::{fhir, global_env::GlobalEnv};
use rustc_hir::def_id::DefId;
use rustc_span::{symbol::Ident, Span};
use rustc_span::{symbol::Ident, Span, Symbol};

#[derive(Diagnostic)]
#[diag(fhir_analysis_assoc_type_not_found, code = E0999)]
Expand Down Expand Up @@ -2421,15 +2456,15 @@ mod errors {
}

#[derive(Diagnostic)]
#[diag(fhir_analysis_generics_on_type_parameter, code = E0999)]
pub(super) struct GenericsOnTyParam {
#[diag(fhir_analysis_generics_on_sort_ty_param, code = E0999)]
pub(super) struct GenericsOnSortTyParam {
#[primary_span]
#[label]
span: Span,
found: usize,
}

impl GenericsOnTyParam {
impl GenericsOnSortTyParam {
pub(super) fn new(span: Span, found: usize) -> Self {
Self { span, found }
}
Expand Down Expand Up @@ -2464,4 +2499,27 @@ mod errors {
Self { span, found }
}
}

#[derive(Diagnostic)]
#[diag(fhir_analysis_generics_on_prim_ty, code = E0999)]
pub(super) struct GenericsOnPrimTy {
#[primary_span]
pub span: Span,
pub name: &'static str,
}

#[derive(Diagnostic)]
#[diag(fhir_analysis_generics_on_ty_param, code = E0999)]
pub(super) struct GenericsOnTyParam {
#[primary_span]
pub span: Span,
pub name: Symbol,
}

#[derive(Diagnostic)]
#[diag(fhir_analysis_generics_on_self_ty, code = E0999)]
pub(super) struct GenericsOnSelfTy {
#[primary_span]
pub span: Span,
}
}
11 changes: 6 additions & 5 deletions crates/flux-fhir-analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,11 @@ fn generics_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult<rty::Generics
fn refinement_generics_of(
genv: GlobalEnv,
local_id: LocalDefId,
) -> QueryResult<rty::RefinementGenerics> {
) -> QueryResult<rty::EarlyBinder<rty::RefinementGenerics>> {
let parent = genv.tcx().generics_of(local_id).parent;
let parent_count =
if let Some(def_id) = parent { genv.refinement_generics_of(def_id)?.count() } else { 0 };
match genv.map().node(local_id)? {
let generics = match genv.map().node(local_id)? {
fhir::Node::Item(fhir::Item {
kind: fhir::ItemKind::Fn(..) | fhir::ItemKind::TyAlias(..),
generics,
Expand All @@ -420,10 +420,11 @@ fn refinement_generics_of(
}) => {
let wfckresults = genv.check_wf(local_id)?;
let params = conv::conv_refinement_generics(generics.refinement_params, &wfckresults)?;
Ok(rty::RefinementGenerics { parent, parent_count, own_params: params })
rty::RefinementGenerics { parent, parent_count, own_params: params }
}
_ => Ok(rty::RefinementGenerics { parent, parent_count, own_params: rty::List::empty() }),
}
_ => rty::RefinementGenerics { parent, parent_count, own_params: rty::List::empty() },
};
Ok(rty::EarlyBinder(generics))
}

fn type_of(genv: GlobalEnv, def_id: LocalDefId) -> QueryResult<rty::EarlyBinder<rty::TyOrCtor>> {
Expand Down
26 changes: 13 additions & 13 deletions crates/flux-fhir-analysis/src/wf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ mod errors;
mod param_usage;
mod sortck;

use std::iter;

use flux_common::result::{ErrorCollector, ResultExt as _};
use flux_errors::{Errors, FluxSession};
use flux_middle::{
Expand Down Expand Up @@ -377,31 +375,29 @@ impl<'genv> fhir::visit::Visitor<'genv> for Wf<'_, 'genv, '_> {
}

fn visit_path(&mut self, path: &fhir::Path<'genv>) {
let genv = self.genv();
if let fhir::Res::Def(DefKind::TyAlias, def_id) = path.res {
let Some(generics) = self
.infcx
.genv
.refinement_generics_of(def_id)
.emit(&self.errors)
.ok()
else {
let Ok(generics) = genv.refinement_generics_of(def_id).emit(&self.errors) else {
return;
};

if path.refine.len() != generics.own_params.len() {
if path.refine.len() != generics.count() {
self.errors.emit(errors::EarlyBoundArgCountMismatch::new(
path.span,
generics.own_params.len(),
generics.count(),
path.refine.len(),
));
}

for (expr, param) in iter::zip(path.refine, &generics.own_params) {
let args = self.infcx.path_args(path.fhir_id);
for (i, expr) in path.refine.iter().enumerate() {
let Ok(param) = generics.param_at(i, genv).emit(&self.errors) else { return };
let param = param.instantiate(genv.tcx(), &args, &[]);
self.infcx
.check_expr(expr, &param.sort)
.collect_err(&mut self.errors);
}
}
};
fhir::visit::walk_path(self, path);
}
}
Expand Down Expand Up @@ -471,6 +467,10 @@ impl<'genv, 'tcx> ConvPhase<'genv, 'tcx> for Wf<'_, 'genv, 'tcx> {
self.infcx.insert_sort_for_bty(fhir_id, sort);
}

fn insert_path_args(&mut self, fhir_id: FhirId, args: rty::GenericArgs) {
self.infcx.insert_path_args(fhir_id, args);
}

fn insert_alias_reft_sort(&mut self, fhir_id: FhirId, fsort: rty::FuncSort) {
self.infcx.insert_sort_for_alias_reft(fhir_id, fsort);
}
Expand Down
Loading

0 comments on commit c6ad442

Please sign in to comment.