Skip to content

Commit

Permalink
fix: refinement type assert cast bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Feb 14, 2024
1 parent 1762588 commit 27ede53
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 6 deletions.
8 changes: 8 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ impl<T> Triple<T, T> {
Triple::Ok(a) | Triple::Err(a) => Some(a),
}
}

pub fn merge_or(self, default: T) -> T {
match self {
Triple::None => default,
Triple::Ok(ok) => ok,
Triple::Err(err) => err,
}
}
}

impl<T, E: std::error::Error> Triple<T, E> {
Expand Down
37 changes: 31 additions & 6 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,7 @@ impl Context {
/// union(Array(Int, 2), Array(Str, 3)) == Array(Int, 2) or Array(Int, 3)
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int }
/// union((A and B) or C) == (A or C) and (B or C)
/// ```
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
Expand Down Expand Up @@ -1345,6 +1346,16 @@ impl Context {
_ => self.simple_union(lhs, rhs),
},
(other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other),
// (A and B) or C ==> (A or C) and (B or C)
(and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => {
let ands = and_t.ands();
let mut t = Type::Obj;
for branch in ands.iter() {
let union = self.union(branch, other);
t = and(t, union);
}
t
}
(t, Type::Never) | (Type::Never, t) => t.clone(),
// Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2)
(
Expand Down Expand Up @@ -1497,12 +1508,6 @@ impl Context {
self.intersection(&fv.crack(), other)
}
(Refinement(l), Refinement(r)) => Type::Refinement(self.intersection_refinement(l, r)),
(other, Refinement(refine)) | (Refinement(refine), other) => {
let other = other.clone().into_refinement();
let intersec = self.intersection_refinement(&other, refine);
self.try_squash_refinement(intersec)
.unwrap_or_else(Type::Refinement)
}
(Structural(l), Structural(r)) => self.intersection(l, r).structuralize(),
(Guard(l), Guard(r)) => {
if l.namespace == r.namespace && l.target == r.target {
Expand All @@ -1527,6 +1532,26 @@ impl Context {
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
self.intersection_add(and, other)
}
// (A or B) and C == (A and C) or (B and C)
(or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => {
let ors = or_t.ors();
let mut t = Type::Never;
for branch in ors.iter() {
let isec = self.intersection(branch, other);
if branch.is_unbound_var() {
t = or(t, isec);
} else {
t = self.union(&t, &isec);
}
}
t
}
(other, Refinement(refine)) | (Refinement(refine), other) => {
let other = other.clone().into_refinement();
let intersec = self.intersection_refinement(&other, refine);
self.try_squash_refinement(intersec)
.unwrap_or_else(Type::Refinement)
}
// overloading
(l, r) if l.is_subr() && r.is_subr() => and(lhs.clone(), rhs.clone()),
_ => self.simple_intersection(lhs, rhs),
Expand Down
1 change: 1 addition & 0 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3657,6 +3657,7 @@ impl Context {
/// ```erg
/// recover_typarams(Int, Nat) == Nat
/// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2)
/// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"}
/// ```
/// ```erg
/// # REVIEW: should be?
Expand Down
8 changes: 8 additions & 0 deletions crates/erg_compiler/tests/infer.er
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ c_new x, y = C.new x, y
C = Class Int
C.
new x, y = Self x + y

val!() =
for! [{ "a": "b" }], (pkg as {Str: Str}) =>
x = pkg.get("a", "c")
assert x in {"b"}
val!::return x
"d"
val = val!()
3 changes: 3 additions & 0 deletions crates/erg_compiler/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> {
let c_new_t = func2(add_r, r, c.clone()).quantify();
module.context.assert_var_type("c_new", &c_new_t)?;
module.context.assert_attr_type(&c, "new", &c_new_t)?;
module
.context
.assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?;
Ok(())
}

Expand Down

0 comments on commit 27ede53

Please sign in to comment.