Skip to content

Commit

Permalink
[test] math intrin
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Mar 11, 2020
1 parent 08d961b commit fa52eac
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
Expand Down
98 changes: 98 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,38 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x)


def exp2(x):
"""Calculate 2**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp2", x)


def exp10(x):
"""Calculate 10**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp10", x)


def erf(x):
"""Take gauss error function of the input x.
Expand Down Expand Up @@ -393,6 +425,38 @@ def log(x):
"""
return call_pure_intrin(x.dtype, "log", x)


def log2(x):
"""Take log2 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log2", x)


def log10(x):
"""Take log10 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log10", x)

def tan(x):
"""Take tan of input x.
Expand Down Expand Up @@ -424,6 +488,23 @@ def cos(x):
"""
return call_pure_intrin(x.dtype, "cos", x)


def cosh(x):
"""Take cosh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "cosh", x)


def sin(x):
"""Take sin of input x.
Expand All @@ -439,6 +520,23 @@ def sin(x):
"""
return call_pure_intrin(x.dtype, "sin", x)


def sinh(x):
"""Take sin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "sinh", x)


def atan(x):
"""Take atan of input x.
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_tvm_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,27 @@ def test_nearbyint():
a_rounded.asnumpy(), np.rint(a.asnumpy()))


def test_unary_intrin(tvm_intrin, np_func):
m = te.var("m",)
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name='B')
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
b = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b)
tvm.testing.assert_allclose(
b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_nearbyint()
test_unary_intrin(tvm.tir.exp2, lambda x : np.exp2(x))
test_unary_intrin(tvm.tir.exp10, lambda x : np.power(10, x))
test_unary_intrin(tvm.tir.log2, lambda x : np.log2(x))
test_unary_intrin(tvm.tir.log10, lambda x : np.log10(x))
test_unary_intrin(tvm.tir.sinh, lambda x : np.sinh(x))
test_unary_intrin(tvm.tir.cosh, lambda x : np.cosh(x))

0 comments on commit fa52eac

Please sign in to comment.