Skip to content

Commit

Permalink
[Relay] Add foldr1
Browse files Browse the repository at this point in the history
  • Loading branch information
Li Xiaoquan committed Mar 30, 2019
1 parent 8eef156 commit 28ea075
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
24 changes: 24 additions & 0 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ def define_list_foldr(self):
self.mod[self.foldr] = Function([f, bv, av],
Match(av, [nil_case, cons_case]), b, [a, b])

def define_list_foldr1(self):
"""Defines a right-way fold over an un-empty list.
foldr1(f, l) : fn<a>(fn(a, a) -> a, list[a]) -> a
foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil)))))
evalutes to f(a1, f(a2, f(..., f(an-1, an)))...)
"""
self.foldr1 = GlobalVar("foldr1")
a = TypeVar("a")
f = Var("f", FuncType([a, a], a))
av = Var("av", self.l(a))
x = Var("x")
y = Var("y")
z = Var("z")
one_case = Clause(PatternConstructor(self.cons,
[PatternVar(x), PatternConstructor(self.nil)]), x)
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]),
f(y, self.foldr1(f, z)))
self.mod[self.foldr1] = Function([f, av],
Match(av, [one_case, cons_case]), a, [a])


def define_list_concat(self):
"""Defines a function that concatenates two lists.
Expand Down Expand Up @@ -471,6 +494,7 @@ def __init__(self, mod):
self.define_list_map()
self.define_list_foldl()
self.define_list_foldr()
self.define_list_foldr1()
self.define_list_concat()
self.define_list_filter()
self.define_list_zip()
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/kind_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
for (const auto& con : op->constructors) {
if (!con->belong_to.same_as(op->header)) {
ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to
<< " but " << op << "has header " << op->header));
<< " but " << op << " has header " << op->header));
}

for (const Type& t : con->inputs) {
Expand Down
19 changes: 19 additions & 0 deletions tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
map = p.map
foldl = p.foldl
foldr = p.foldr
foldr1 = p.foldr1
sum = p.sum

concat = p.concat
Expand Down Expand Up @@ -228,6 +229,23 @@ def test_foldr():
assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3


def test_foldr1():
a = relay.TypeVar("a")
lhs = mod[p.foldr1].checked_type
rhs = relay.FuncType([relay.FuncType([a, a], a), l(a)], a, [a])
assert lhs == rhs

x = relay.Var("x")
y = relay.Var("y")
f = relay.Function([x, y], add(x, y))
res = intrp.evaluate(foldr1(f,
cons(build_nat(1),
cons(build_nat(2),
cons(build_nat(3), nil())))))

assert count(res) == 6


def test_sum():
assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())
res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil()))))
Expand Down Expand Up @@ -647,6 +665,7 @@ def test_iterate():
test_map()
test_foldl()
test_foldr()
test_foldr1()
test_concat()
test_filter()
test_zip()
Expand Down

0 comments on commit 28ea075

Please sign in to comment.