Skip to content

Commit

Permalink
[Relay, Topi, TF Frontend] Isfinite operator (#4981)
Browse files Browse the repository at this point in the history
* isfinite doc update

* isfinit expr

* isfinit expr

* isfinite schedule reg

* isfinite python binding

* isfinite python binding

* relay register isfinite

* isfinite type relation

* intrin isfinite

* topi isfinite

* testcase topi isfinite

* tf frontend isfinite

* tf frontend isfinite testcase

* test case relay isfinite

* small fixes

* test forward tf isfinite

* test cases injective for cuda

* remove float16 test case

* add support for isinf

* remove unwanted import

* fix conflict
  • Loading branch information
maheshambule authored Mar 23, 2020
1 parent fdc8b0d commit 9037f4e
Show file tree
Hide file tree
Showing 20 changed files with 319 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ List of operators
topi.round
topi.abs
topi.isnan
topi.isfinite
topi.isinf
topi.exp
topi.tanh
topi.log
Expand Down Expand Up @@ -134,6 +136,8 @@ topi
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.isnan
.. autofunction:: topi.isfinite
.. autofunction:: topi.isinf
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
Expand Down
2 changes: 2 additions & 0 deletions docs/frontend/tensorflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ Supported Ops
- Greater
- GreaterEqual
- Identity
- IsFinite
- IsInf
- LeakyRelu
- LeftShift
- Less
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ class CallNode : public PrimExprNode {
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
static constexpr const char* isnan = "isnan";
static constexpr const char* isfinite = "isfinite";
static constexpr const char* isinf = "isinf";

/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
Expand Down
21 changes: 21 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ TVM_DLL PrimExpr max_value(const DataType& dtype);
*/
TVM_DLL PrimExpr min_value(const DataType& dtype);

/*!
* Get the value of infinity.
* \param dtype The data type.
* \return the infinity value in this format.
*/
TVM_DLL PrimExpr infinity(const DataType& dtype);

/*!
* \brief cast value to type.
*
Expand Down Expand Up @@ -439,6 +446,20 @@ TVM_DLL PrimExpr abs(PrimExpr x);
*/
TVM_DLL PrimExpr isnan(PrimExpr x);

/*!
* \brief Check if x is finite.
* \param x The input data
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x);

/*!
* \brief Check if x is infinite.
* \param x The input data
* \return The result expression.
*/
TVM_DLL PrimExpr isinf(PrimExpr x);

/*!
* \brief sum of of source expression over axis
* \param source The source expression.
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1667,6 +1667,8 @@ def _impl(inputs, attr, params):
'Greater' : _broadcast('greater'),
'GreaterEqual' : _broadcast('greater_equal'),
'Identity' : _identity(),
'IsFinite' : AttrCvt('isfinite'),
'IsInf' : AttrCvt('isinf'),
'LeakyRelu' : AttrCvt('leaky_relu'),
'LeftShift' : AttrCvt('left_shift'),
'Less' : _broadcast('less'),
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
register_broadcast_schedule("less_equal")
register_broadcast_schedule("greater")
register_broadcast_schedule("greater_equal")
register_broadcast_schedule("isfinite")
register_broadcast_schedule("isinf")
register_injective_schedule("maximum")
register_injective_schedule("minimum")
register_injective_schedule("right_shift")
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,3 +1008,35 @@ def ndarray_size(data, dtype="int32"):
The number of elements of input tensor.
"""
return _make.ndarray_size(data, dtype)


def isfinite(data):
"""Compute element-wise finiteness of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.isfinite(data)


def isinf(data):
"""Compute element-wise infiniteness of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.isinf(data)
3 changes: 2 additions & 1 deletion python/tvm/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from tvm.tir import comm_reducer, min, max, sum

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum

Expand Down
32 changes: 32 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,38 @@ def isnan(x):
return _ffi_api.isnan(x)


def isfinite(x):
"""Check if input value is finite.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _ffi_api.isfinite(x)


def isinf(x):
"""Check if input value is infinite.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _ffi_api.isinf(x)


def power(x, y):
"""x power y
Expand Down
18 changes: 18 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,5 +415,23 @@ ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);

RELAY_REGISTER_UNARY_OP("isfinite")
.describe(R"code(Returns the finiteness of input, computed element-wise.
.. math::
isfinite(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite));

RELAY_REGISTER_UNARY_OP("isinf")
.describe(R"code(Returns the infiniteness of input, computed element-wise.
.. math::
isfinite(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf));

} // namespace relay
} // namespace tvm
12 changes: 12 additions & 0 deletions src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ bool BroadcastCompRel(const Array<Type>& types,
return false;
}

bool IdentityCompRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
if (auto* t0 = types[0].as<TensorTypeNode>()) {
Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool());
reporter->Assign(types[1], out_type);
return true;
}
return false;
}

Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
if (shape.size() == 0) {
return {};
Expand Down
5 changes: 5 additions & 0 deletions src/relay/op/type_relations.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter);

bool IdentityCompRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);

Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);

} // namespace relay
Expand Down
16 changes: 16 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
*rv = one / (one + exp(-call->args[0]));
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isfinite(call->args[0]);
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isinf(call->args[0]);
});

} // namespace intrin
} // namespace codegen
} // namespace tvm
36 changes: 36 additions & 0 deletions src/tir/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,21 @@ PrimExpr min_value(const DataType& dtype) {
return PrimExpr();
}

// infinity
PrimExpr infinity(const DataType& dtype) {
using namespace tir;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_float()) {
if (dtype.bits() == 64) {
return FloatImm(dtype, std::numeric_limits<double>::infinity());
} else if (dtype.bits() == 32 || dtype.bits() == 16) {
return FloatImm(dtype, std::numeric_limits<float>::infinity());
}
}
LOG(FATAL) << "Cannot decide infinity for type " << dtype;
return PrimExpr();
}

namespace tir {
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
Expand Down Expand Up @@ -575,6 +590,21 @@ PrimExpr isnan(PrimExpr x) {
}
}

PrimExpr isinf(PrimExpr x) {
DataType t = DataType::Bool(x.dtype().lanes());
if (x.dtype().is_int() || x.dtype().is_uint()) {
return make_const(t, false);
} else if (x.dtype().is_float()) {
PrimExpr infX = infinity(x.dtype());
return abs(x) == infX && !isnan(x);
} else {
LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it...";
return x;
}
}

PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }

PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::AddNode::make(x, y);
Expand Down Expand Up @@ -721,6 +751,12 @@ TVM_REGISTER_GLOBAL("tir.abs")
TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan);

TVM_REGISTER_GLOBAL("tir.isfinite")
.set_body_typed(tvm::isfinite);

TVM_REGISTER_GLOBAL("tir.isinf")
.set_body_typed(tvm::isinf);

TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor);

Expand Down
34 changes: 33 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3152,7 +3152,37 @@ def test_forward_dilation():
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")

# #######################################################################

#######################################################################
# infinity ops
# ------------
def _verify_infiniteness_ops(tf_op, name):
"""test operator infinity ops"""

# Only float types are allowed in Tensorflow for isfinite and isinf
# float16 is failing on cuda
tf_dtypes = ["float32", "float64"]
for tf_dtype in tf_dtypes:
shape = (8, 8)
data = np.random.uniform(size=shape).astype(tf_dtype)
data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan

tf.reset_default_graph()
in_data = tf.placeholder(tf_dtype, shape, name="in_data")
tf_op(in_data, name=name)
compare_tf_with_tvm([data], ['in_data:0'], '{}:0'.format(name))


def test_forward_isinf():
_verify_infiniteness_ops(tf.is_inf, "isinf")


def test_forward_isfinite():
_verify_infiniteness_ops(tf.is_finite, "isfinite")


#######################################################################
# Main
# ----
if __name__ == '__main__':
Expand Down Expand Up @@ -3224,6 +3254,8 @@ def test_forward_dilation():
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
test_forward_isfinite()
test_forward_isinf()
test_forward_unravel_index()

# Reductions
Expand Down
Loading

0 comments on commit 9037f4e

Please sign in to comment.