From 28ea0755343b2daa5fb914d19a8fb46aed182a86 Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Sat, 30 Mar 2019 11:42:27 +0800 Subject: [PATCH] [Relay] Add foldr1 --- python/tvm/relay/prelude.py | 24 ++++++++++++++++++++++++ src/relay/pass/kind_check.cc | 2 +- tests/python/relay/test_adt.py | 19 +++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 6cf104ab388a5..8157380ad9535 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -142,6 +142,29 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), b, [a, b]) + def define_list_foldr1(self): + """Defines a right-way fold over an un-empty list. + + foldr1(f, l) : fn(fn(a, a) -> a, list[a]) -> a + + foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil))))) + evalutes to f(a1, f(a2, f(..., f(an-1, an)))...) + """ + self.foldr1 = GlobalVar("foldr1") + a = TypeVar("a") + f = Var("f", FuncType([a, a], a)) + av = Var("av", self.l(a)) + x = Var("x") + y = Var("y") + z = Var("z") + one_case = Clause(PatternConstructor(self.cons, + [PatternVar(x), PatternConstructor(self.nil)]), x) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + f(y, self.foldr1(f, z))) + self.mod[self.foldr1] = Function([f, av], + Match(av, [one_case, cons_case]), a, [a]) + + def define_list_concat(self): """Defines a function that concatenates two lists. @@ -471,6 +494,7 @@ def __init__(self, mod): self.define_list_map() self.define_list_foldl() self.define_list_foldr() + self.define_list_foldr1() self.define_list_concat() self.define_list_filter() self.define_list_zip() diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index f1e539d71d48b..52c0afdd0c75a 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -144,7 +144,7 @@ struct KindChecker : TypeFunctor { for (const auto& con : op->constructors) { if (!con->belong_to.same_as(op->header)) { ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to - << " but " << op << "has header " << op->header)); + << " but " << op << " has header " << op->header)); } for (const Type& t : con->inputs) { diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index e9e2915f28a89..7b76b95e89f72 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -31,6 +31,7 @@ map = p.map foldl = p.foldl foldr = p.foldr +foldr1 = p.foldr1 sum = p.sum concat = p.concat @@ -228,6 +229,23 @@ def test_foldr(): assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3 +def test_foldr1(): + a = relay.TypeVar("a") + lhs = mod[p.foldr1].checked_type + rhs = relay.FuncType([relay.FuncType([a, a], a), l(a)], a, [a]) + assert lhs == rhs + + x = relay.Var("x") + y = relay.Var("y") + f = relay.Function([x, y], add(x, y)) + res = intrp.evaluate(foldr1(f, + cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + + assert count(res) == 6 + + def test_sum(): assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil())))) @@ -647,6 +665,7 @@ def test_iterate(): test_map() test_foldl() test_foldr() + test_foldr1() test_concat() test_filter() test_zip()