Skip to content

Commit

Permalink
[TOPI] Fix declaration for different dtypes (#546)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and tqchen committed Oct 13, 2017
1 parent b384cd4 commit b20678b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import make as _make
from . import _api_internal

class ExprOp(object):
def __add__(self, other):
Expand Down Expand Up @@ -60,7 +61,8 @@ def __mod__(self, other):
return _make.Mod(self, other)

def __neg__(self):
return self.__mul__(-1)
neg_one = _api_internal._const(-1, self.dtype)
return self.__mul__(neg_one)

def __lshift__(self, other):
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/nn/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def relu(x):
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), 0))
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), tvm.const(0, x.dtype)))


@tvm.tag_scope(tag=tag.ELEMWISE)
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/nn/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def global_pool(data, pool_type):
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_sum")
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width), \
tsum[n, c, h, w] / (height*width).astype(tsum.dtype), \
tag=tag.ELEMWISE)
else:
raise ValueError("Pool type should be 'avg' or 'max'.")
Expand Down

0 comments on commit b20678b

Please sign in to comment.