Skip to content

Commit

Permalink
fix: union type bug (2)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Sep 15, 2024
1 parent 3b9bbdf commit 2f77a24
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 54 deletions.
18 changes: 17 additions & 1 deletion crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,23 @@ impl Context {
(Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)),
// Int :> (Nat or Str) == Int :> Nat && Int :> Str == false
(lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)),
(And(l), And(r)) => r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))),
// Hash and Eq :> HashEq and ... == true
// Add(T) and Eq :> Add(Int) and Eq == true
(And(l), And(r)) => {
if r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))) {
return true;
}
if l.len() == r.len() {
let mut r = r.clone();
for _ in 1..l.len() {
if l.iter().zip(&r).all(|(l, r)| self.supertype_of(l, r)) {
return true;
}
r.rotate_left(1);
}
}
false
}
// (Num and Show) :> Show == false
(And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)),
// Show :> (Num and Show) == true
Expand Down
8 changes: 4 additions & 4 deletions crates/erg_compiler/context/initialize/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,10 @@ impl Context {
neg.register_builtin_erg_decl(OP_NEG, op_t, Visibility::BUILTIN_PUBLIC);
neg.register_builtin_erg_decl(OUTPUT, Type, Visibility::BUILTIN_PUBLIC);
/* Num */
let mut num = Self::builtin_mono_trait(NUM, 2);
num.register_superclass(poly(ADD, vec![]), &add);
num.register_superclass(poly(SUB, vec![]), &sub);
num.register_superclass(poly(MUL, vec![]), &mul);
let num = Self::builtin_mono_trait(NUM, 2);
// num.register_superclass(poly(ADD, vec![]), &add);
// num.register_superclass(poly(SUB, vec![]), &sub);
// num.register_superclass(poly(MUL, vec![]), &mul);
/* ToBool */
let mut to_bool = Self::builtin_mono_trait(TO_BOOL, 2);
let _Slf = mono_q(SELF, subtypeof(mono(TO_BOOL)));
Expand Down
17 changes: 13 additions & 4 deletions crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1943,6 +1943,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
/// ```erg
/// unify(Int, Nat) == Some(Int)
/// unify(Int, Str) == None
/// unify(T, Never) == Some(T)
/// unify({1.2}, Nat) == Some(Float)
/// unify(Nat, Int!) == Some(Int)
/// unify(Eq, Int) == None
Expand All @@ -1951,13 +1952,21 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
/// ```
fn unify(&self, lhs: &Type, rhs: &Type) -> Option<Type> {
match (lhs, rhs) {
(Never, other) | (other, Never) => {
return Some(other.clone());
}
(Or(tys), other) | (other, Or(tys)) => {
let mut unified = Never;
for ty in tys {
if let Some(t) = self.unify(ty, other) {
return self.unify(&t, ty);
unified = self.ctx.union(&unified, &t);
}
}
return None;
if unified != Never {
return Some(unified);
} else {
return None;
}
}
(FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs),
(_, FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()),
Expand All @@ -1981,11 +1990,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
let l_sups = self.ctx.get_super_classes(lhs)?;
let r_sups = self.ctx.get_super_classes(rhs)?;
for l_sup in l_sups {
if self.ctx.supertype_of(&l_sup, &Obj) {
if l_sup == Obj || self.ctx.is_trait(&l_sup) {
continue;
}
for r_sup in r_sups.clone() {
if self.ctx.supertype_of(&r_sup, &Obj) {
if r_sup == Obj || self.ctx.is_trait(&r_sup) {
continue;
}
if let Some(t) = self.ctx.max(&l_sup, &r_sup).either() {
Expand Down
62 changes: 25 additions & 37 deletions crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,9 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
}
}

fn elem_err(&self, l: &Type, r: &Type, elem: &hir::Expr) -> LowerErrors {
fn elem_err(&self, union: Type, elem: &hir::Expr) -> LowerErrors {
let elem_disp_notype = elem.to_string_notype();
let l = self.module.context.readable_type(l.clone());
let r = self.module.context.readable_type(r.clone());
let union = self.module.context.readable_type(union);
LowerErrors::from(LowerError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
Expand All @@ -379,10 +378,10 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
)
.to_owned(),
Some(switch_lang!(
"japanese" => format!("[..., {elem_disp_notype}: {l} or {r}]など明示的に型を指定してください"),
"simplified_chinese" => format!("请明确指定类型,例如: [..., {elem_disp_notype}: {l} or {r}]"),
"traditional_chinese" => format!("請明確指定類型,例如: [..., {elem_disp_notype}: {l} or {r}]"),
"english" => format!("please specify the type explicitly, e.g. [..., {elem_disp_notype}: {l} or {r}]"),
"japanese" => format!("[..., {elem_disp_notype}: {union}]など明示的に型を指定してください"),
"simplified_chinese" => format!("请明确指定类型,例如: [..., {elem_disp_notype}: {union}]"),
"traditional_chinese" => format!("請明確指定類型,例如: [..., {elem_disp_notype}: {union}]"),
"english" => format!("please specify the type explicitly, e.g. [..., {elem_disp_notype}: {union}]"),
)),
))
}
Expand Down Expand Up @@ -453,36 +452,25 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
union: &Type,
elem: &hir::Expr,
) -> LowerResult<()> {
if ERG_MODE && expect_elem.is_none() {
if let Some((l, r)) = union_.union_pair() {
match (l.is_unbound_var(), r.is_unbound_var()) {
// e.g. [1, "a"]
(false, false) => {
if let hir::Expr::TypeAsc(type_asc) = elem {
// e.g. [1, "a": Str or NoneType]
if !self
.module
.context
.supertype_of(&type_asc.spec.spec_t, union)
{
return Err(self.elem_err(&l, &r, elem));
} // else(OK): e.g. [1, "a": Str or Int]
}
// OK: ?T(:> {"a"}) or ?U(:> {"b"}) or {"c", "d"} => {"a", "b", "c", "d"} <: Str
else if self
.module
.context
.coerce(union_.derefine(), &())
.map_or(true, |coerced| coerced.union_pair().is_some())
{
return Err(self.elem_err(&l, &r, elem));
}
}
// TODO: check if the type is compatible with the other type
(true, false) => {}
(false, true) => {}
(true, true) => {}
}
if ERG_MODE && expect_elem.is_none() && union_.union_size() > 1 {
if let hir::Expr::TypeAsc(type_asc) = elem {
// e.g. [1, "a": Str or NoneType]
if !self
.module
.context
.supertype_of(&type_asc.spec.spec_t, union)
{
return Err(self.elem_err(union_.clone(), elem));
} // else(OK): e.g. [1, "a": Str or Int]
}
// OK: ?T(:> {"a"}) or ?U(:> {"b"}) or {"c", "d"} => {"a", "b", "c", "d"} <: Str
else if self
.module
.context
.coerce(union_.derefine(), &())
.map_or(true, |coerced| coerced.union_pair().is_some())
{
return Err(self.elem_err(union_.clone(), elem));
}
}
Ok(())
Expand Down
23 changes: 15 additions & 8 deletions crates/erg_compiler/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,9 @@ impl PartialEq for Type {
(Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs,
(Self::Refinement(l), Self::Refinement(r)) => l == r,
(Self::Quantified(l), Self::Quantified(r)) => l == r,
(Self::And(l), Self::And(r)) => l.iter().collect::<Set<_>>().linear_eq(&r.iter().collect()),
(Self::And(l), Self::And(r)) => {
l.iter().collect::<Set<_>>().linear_eq(&r.iter().collect())
}
(Self::Or(l), Self::Or(r)) => l.linear_eq(r),
(Self::Not(l), Self::Not(r)) => l == r,
(
Expand Down Expand Up @@ -1872,7 +1874,7 @@ impl BitAnd for Type {
r.push(l);
Self::And(r)
}
(l, r) => Self::checked_and(vec! {l, r}),
(l, r) => Self::checked_and(vec![l, r]),
}
}
}
Expand Down Expand Up @@ -3023,12 +3025,8 @@ impl Type {
Self::ProjCall { lhs, args, .. } => {
lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f))
}
Self::And(tys) => {
tys.iter().any(|t| t.has_type_satisfies(f))
}
Self::Or(tys) => {
tys.iter().any(|t| t.has_type_satisfies(f))
}
Self::And(tys) => tys.iter().any(|t| t.has_type_satisfies(f)),
Self::Or(tys) => tys.iter().any(|t| t.has_type_satisfies(f)),
Self::Not(t) => t.has_type_satisfies(f),
Self::Ref(t) => t.has_type_satisfies(f),
Self::RefMut { before, after } => {
Expand Down Expand Up @@ -3396,6 +3394,15 @@ impl Type {
}
}

pub fn union_types(&self) -> Option<Set<Type>> {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_types(),
Self::Refinement(refine) => refine.t.union_types(),
Self::Or(tys) => Some(tys.clone()),
_ => None,
}
}

/// assert!((A or B).contains_union(B))
pub fn contains_union(&self, typ: &Type) -> bool {
match self {
Expand Down

0 comments on commit 2f77a24

Please sign in to comment.