Commit b1d8cf2

Fix implementation of tan in cuda. Do not support tan for float16.
Simplify topi/tests/python/test_topi_math. Add testing for tan with float32 and float64.

Finally implement tan as sin/cos in llvm.
notoraptor committed Mar 10, 2020
1 parent 68e7088 commit b1d8cf2
Showing 3 changed files with 25 additions and 70 deletions.
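
For context on the llvm change below, here is a minimal numpy sketch (not part of the commit) of the identity the new lowering relies on, tan(x) = sin(x) / cos(x), using the same value range and tolerances as the topi tests:

# Sanity check of tan(x) == sin(x) / cos(x); ranges/tolerances mirror the topi tests.
import numpy as np

for dtype in ("float32", "float64"):
    x = np.random.uniform(-2.0 * np.pi, 2.0 * np.pi, size=(20, 3)).astype(dtype)
    np.testing.assert_allclose(np.sin(x) / np.cos(x), np.tan(x), rtol=1e-5, atol=1e-5)
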
28 changes: 6 additions & 22 deletions src/target/llvm/intrin_rule_llvm.cc
@@ -93,31 +93,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
-  using tir::make_const;
   PrimExpr e = targs[0];
   const tir::CallNode* call = e.as<tir::CallNode>();
   CHECK(call != nullptr);
-  PrimExpr x = call->args[0];
-  PrimExpr y = x;
-  DataType dtype = x.dtype();
-  const char* opName = nullptr;
-
-  if (!dtype.is_float()) {
-    LOG(FATAL) << "tan expects floating input";
-  }
-
-  if (dtype.bits() == 64) {
-    opName = "tan";
-  } else if (dtype.bits() == 32) {
-    opName = "tanf";
-  } else if (dtype.bits() == 16) {
-    opName = "tanf";
-    y = cast(DataType::Float(32, dtype.lanes()), x);
-  } else {
-    LOG(FATAL) << "tan cannot handle float" << dtype.bits();
-  }
-
-  PrimExpr tan_x = tir::CallNode::make(x.dtype(), opName, {y}, tir::CallNode::Extern);
+  const PrimExpr& x = call->args[0];
+  PrimExpr sin_x = tir::CallNode::make(
+      x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr cos_x = tir::CallNode::make(
+      x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr tan_x = sin_x / cos_x;
   *rv = tan_x;
 });

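For illustration (not part of the commit), a hedged end-to-end sketch of the new llvm path; API names follow the TVM revision of this commit and mirror the topi test further below:

# Build topi.tan for the llvm target and compare against np.tan.
# The rule registered above rewrites tan into sin/cos during codegen.
import numpy as np
import tvm
import topi
from tvm import te

shape = (20, 3)
A = te.placeholder(shape, dtype="float32", name="A")
B = topi.tan(A)
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm", name="tan")

ctx = tvm.cpu(0)
a_np = np.random.uniform(-2.0 * np.pi, 2.0 * np.pi, size=shape).astype("float32")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(shape, dtype="float32"), ctx)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), np.tan(a_np), rtol=1e-5, atol=1e-5)
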
18 changes: 17 additions & 1 deletion src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath {
   }
 };

+struct CUDAFastMathTan : public CUDAMath {
+  std::string operator()(DataType t, std::string name) const {
+    if (t.lanes() == 1 && t.is_float()) {
+      switch (t.bits()) {
+        case 64: return name;
+        // `__tanf` seems to produce some values too deviant from numpy tan version.
+        // So, let's use just `tanf` instead.
+        case 32: return name + 'f';
+        case 16: LOG(FATAL) << "cuda tan unsupported for float16";
+        default: return "";
+      }
+    }
+    return "";
+  }
+};
+
 struct CUDAPopcount {
   std::string operator()(DataType t, std::string name) const {
     if (t.lanes() == 1 && t.is_uint()) {
@@ -98,7 +114,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
-.set_body(DispatchExtern<CUDAFastMath>);
+.set_body(DispatchExtern<CUDAFastMathTan>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
 .set_body(DispatchExtern<CUDAFastMath>);

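For illustration (not part of the commit), a hedged sketch of what the new CUDA rule implies; it assumes a CUDA-enabled TVM build with a visible GPU and reuses the helpers seen in the topi test below:

# float32 tan should dispatch to the extern tanf (not the fast __tanf),
# while float16 is expected to fail when the rule hits LOG(FATAL) at build time.
import tvm
import topi
import topi.testing
from tvm import te

def build_tan(dtype):
    A = te.placeholder((1024,), dtype=dtype, name="A")
    B = topi.tan(A)
    with tvm.target.create("cuda"):
        s = topi.testing.get_injective_schedule("cuda")(B)
    return tvm.build(s, [A, B], "cuda", name="tan")

mod = build_tan("float32")
assert "tanf" in mod.imported_modules[0].get_source()

try:
    build_tan("float16")
except tvm.TVMError:
    print("cuda tan rejected for float16, as intended")
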
49 changes: 2 additions & 47 deletions topi/tests/python/test_topi_math.py
@@ -127,58 +127,13 @@ def check_device(device):
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
     test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
     test_isnan(-100, 100)


-def test_ewise_tan():
-    def test_apply(
-            func,
-            name,
-            f_numpy,
-            low,
-            high,
-            shape=(20, 3),
-            dtype='float32',
-            check_round=False,
-            skip_name_check=False,
-    ):
-        A = te.placeholder(dtype=dtype, name="A", shape=shape)
-
-        B = func(A)
-        assert tuple(B.shape) == tuple(A.shape)
-        if not skip_name_check:
-            assert B.op.body[0].name == name
-        a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
-        # avoid round check too close to boundary
-        if check_round:
-            a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5
-        b_np = f_numpy(a_np)
-
-        def check_device(device):
-            ctx = tvm.context(device, 0)
-            if not ctx.exist:
-                print("Skip because %s is not enabled" % device)
-                return
-            print("Running on target: %s" % device)
-            with tvm.target.create(device):
-                s = topi.testing.get_injective_schedule(device)(B)
-            foo = tvm.build(s, [A, B], device, name=name)
-            a = tvm.nd.array(a_np, ctx)
-            b = tvm.nd.array(np.zeros_like(b_np), ctx)
-            foo(a, b)
-            tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
-
-        for target in get_all_backend():
-            check_device(target)
-
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float16')
-
-
 def test_cast():
     def verify(from_dtype, to_dtype, low=-100, high=100):
         shape = (5, 4)
