Skip to content

Commit

Permalink
add some wf checks for fn types
Browse files Browse the repository at this point in the history
  • Loading branch information
shua committed Jul 2, 2024
1 parent 8732bc3 commit 45e0534
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 73 deletions.
31 changes: 29 additions & 2 deletions crates/formality-prove/src/decls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use formality_core::{set, Set, Upcast};
use formality_macros::term;
use formality_types::grammar::{
AdtId, AliasName, AliasTy, Binder, Parameter, Predicate, Relation, TraitId, TraitRef, Ty, Wc,
Wcs,
AdtId, AliasName, AliasTy, Binder, FnId, Parameter, Predicate, Relation, TraitId, TraitRef, Ty,
Wc, Wcs,
};

#[term]
Expand All @@ -16,6 +16,7 @@ pub struct Decls {
pub alias_eq_decls: Vec<AliasEqDecl>,
pub alias_bound_decls: Vec<AliasBoundDecl>,
pub adt_decls: Vec<AdtDecl>,
pub fn_decls: Vec<FnDecl>,
pub local_trait_ids: Set<TraitId>,
pub local_adt_ids: Set<AdtId>,
}
Expand Down Expand Up @@ -78,6 +79,13 @@ impl Decls {
v.pop().unwrap()
}

pub fn fn_decl(&self, fn_id: &FnId) -> &FnDecl {
let mut v: Vec<_> = self.fn_decls.iter().filter(|t| t.id == *fn_id).collect();
assert!(!v.is_empty(), "no fn named `{fn_id:?}`");
assert!(v.len() <= 1, "multiple fns named `{fn_id:?}`");
v.pop().unwrap()
}

/// Return the set of "trait invariants" for all traits.
/// See [`TraitDecl::trait_invariants`].
pub fn trait_invariants(&self) -> Set<TraitInvariant> {
Expand All @@ -96,6 +104,7 @@ impl Decls {
alias_eq_decls: vec![],
alias_bound_decls: vec![],
adt_decls: vec![],
fn_decls: vec![],
local_trait_ids: set![],
local_adt_ids: set![],
}
Expand Down Expand Up @@ -304,3 +313,21 @@ pub struct AdtDeclBoundData {
/// The where-clauses declared on the ADT,
pub where_clause: Wcs,
}

/// A "function declaration" declares a function name, its generics, its input and ouput types, and its where-clauses.
/// It doesn't currently capture the function body, or input argument names.
///
/// In Rust syntax, it covers the `fn foo<T, U>(_: T) -> U where T: Bar`
#[term(fn $id $binder)]
pub struct FnDecl {
pub id: FnId,
pub binder: Binder<FnDeclBoundData>,
}

/// The "bound data" for a [`FnDecl`][].
#[term(($input_tys) -> $output_ty $:where $where_clause)]
pub struct FnDeclBoundData {
pub input_tys: Vec<Ty>,
pub output_ty: Ty,
pub where_clause: Wcs,
}
22 changes: 22 additions & 0 deletions crates/formality-prove/src/prove/prove_wf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ judgment_fn! {
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::ScalarId(_), parameters }) => c)
)

(
// only checks that type is well-formed, does not do any lifetime or borrow check
(for_all(&decls, &env, &assumptions, &parameters, &prove_wf) => c)
--- ("ref")
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::Ref(_), parameters }) => c)
)

(
(for_all(&decls, &env, &assumptions, &parameters, &prove_wf) => c)
(let t = decls.adt_decl(&adt_id))
Expand All @@ -51,6 +58,21 @@ judgment_fn! {
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::AdtId(adt_id), parameters }) => c)
)

(
(for_all(&decls, &env, &assumptions, &parameters, &prove_wf) => c)
(let t = decls.fn_decl(&fn_id))
(let t = t.binder.instantiate_with(&parameters).unwrap())
(prove_after(&decls, c, &assumptions, t.where_clause) => c)
--- ("fn-defs")
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::FnDef(fn_id), parameters }) => c)
)

(
(for_all(&decls, &env, &assumptions, &parameters, &prove_wf) => c)
--- ("fn-ptr")
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::FnPtr(_), parameters }) => c)
)

(
(prove_wf(&decls, &env, &assumptions, ty) => c)
--- ("rigid constants")
Expand Down
36 changes: 33 additions & 3 deletions crates/formality-rust/src/prove.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::grammar::{
Adt, AdtBoundData, AssociatedTy, AssociatedTyBoundData, AssociatedTyValue,
AssociatedTyValueBoundData, Crate, CrateItem, ImplItem, NegTraitImpl, NegTraitImplBoundData,
Program, Trait, TraitBoundData, TraitImpl, TraitImplBoundData, TraitItem, WhereBound,
WhereBoundData, WhereClause, WhereClauseData,
AssociatedTyValueBoundData, Crate, CrateItem, Fn, FnBoundData, ImplItem, NegTraitImpl,
NegTraitImplBoundData, Program, Trait, TraitBoundData, TraitImpl, TraitImplBoundData,
TraitItem, WhereBound, WhereBoundData, WhereClause, WhereClauseData,
};
use formality_core::{seq, Set, To, Upcast, Upcasted};
use formality_prove as prove;
Expand All @@ -20,6 +20,7 @@ impl Program {
alias_eq_decls: self.alias_eq_decls(),
alias_bound_decls: self.alias_bound_decls(),
adt_decls: self.adt_decls(),
fn_decls: self.fn_decls(),
local_trait_ids: self.local_trait_ids(),
local_adt_ids: self.local_adt_ids(),
}
Expand Down Expand Up @@ -58,6 +59,10 @@ impl Program {
self.crates.iter().flat_map(|c| c.adt_decls()).collect()
}

fn fn_decls(&self) -> Vec<prove::FnDecl> {
self.crates.iter().flat_map(|c| c.fn_decls()).collect()
}

fn local_trait_ids(&self) -> Set<TraitId> {
self.crates
.last()
Expand Down Expand Up @@ -330,6 +335,31 @@ impl Crate {
})
.collect()
}

fn fn_decls(&self) -> Vec<prove::FnDecl> {
self.items
.iter()
.flat_map(|item| match item {
CrateItem::Fn(f) => Some(f),
_ => None,
})
.map(|Fn { id, binder }| prove::FnDecl {
id: id.clone(),
binder: binder.map(
|FnBoundData {
input_tys,
output_ty,
where_clauses,
body: _,
}| prove::FnDeclBoundData {
input_tys,
output_ty,
where_clause: where_clauses.to_wcs(),
},
),
})
.collect()
}
}

pub trait ToWcs {
Expand Down
8 changes: 8 additions & 0 deletions crates/formality-types/src/grammar/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ impl Ty {
}
.upcast()
}

pub fn unit() -> Ty {
RigidTy {
name: RigidName::Tuple(0),
parameters: vec![],
}
.upcast()
}
}

impl UpcastFrom<TyData> for Ty {
Expand Down
19 changes: 19 additions & 0 deletions crates/formality-types/src/grammar/ty/debug_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ impl Debug for RigidTy {
write!(f, "()")
}
}
RigidName::FnDef(name) => {
let parameters = PrettyParameters::new("<", ">", parameters);
write!(f, "{name:?}{parameters:?}",)
}
RigidName::FnPtr(arity) if parameters.len() == *arity + 1 => {
let len = parameters.len();
if *arity != 0 {
write!(
f,
"{:?}",
PrettyParameters::new("(", ")", &parameters[..len - 1])
)?;
} else {
// PrettyParameters would skip the separators
// for 0 arity
write!(f, "()")?;
}
write!(f, "-> {:?}", parameters[len - 1])
}
_ => {
write!(f, "{:?}{:?}", name, PrettyParameters::angle(parameters))
}
Expand Down
33 changes: 32 additions & 1 deletion crates/formality-types/src/grammar/ty/parse_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use formality_core::Upcast;
use formality_core::{seq, Set};

use crate::grammar::{
AdtId, AssociatedItemId, Bool, ConstData, RefKind, RigidName, Scalar, TraitId,
AdtId, AssociatedItemId, Bool, ConstData, FnId, RefKind, RigidName, Scalar, TraitId,
};

use super::{AliasTy, AssociatedTyName, Lt, Parameter, ParameterKind, RigidTy, ScalarId, Ty};
Expand Down Expand Up @@ -70,6 +70,37 @@ impl CoreParse<Rust> for RigidTy {
parameters: types.upcast(),
})
});

parser.parse_variant("Fn", Precedence::default(), |p| {
// parses 'fn name<params>' as fn-def
// or 'fn(params) -> ty' as fn-ptr
p.expect_keyword("fn")?;

if p.expect_char('(').is_ok() {
p.reject_custom_keywords(&["alias", "rigid", "predicate"])?;
let mut types: Vec<Ty> = p.comma_nonterminal()?;
p.expect_char(')')?;
let name = RigidName::FnPtr(types.len());
if p.expect_char('-').is_ok() {
p.expect_char('>')?;
let ret = p.nonterminal()?;
types.push(ret);
} else {
types.push(Ty::unit());
}
Ok(RigidTy {
name,
parameters: types.upcast(),
})
} else {
let name: FnId = p.nonterminal()?;
let parameters: Vec<Parameter> = parse_parameters(p)?;
Ok(RigidTy {
name: RigidName::FnDef(name),
parameters,
})
}
})
})
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/formality-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ formality_core::declare_language! {
const BINDING_CLOSE = '>';
const KEYWORDS = [
"mut",
"fn",
"struct",
"enum",
"union",
Expand Down
Loading

0 comments on commit 45e0534

Please sign in to comment.