[Relay][FastMath] Relay pass to use fast exp/tanh
anijain2305 committed Feb 13, 2020
1 parent aaf62e4 commit 1001f00
Showing 9 changed files with 219 additions and 5 deletions.
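As a rough usage sketch (not part of this commit): because "FastMath" is registered at opt_level 4 in python/tvm/relay/transform.py below, building with opt_level=4 is expected to trigger the rewrite of exp/tanh into their fast counterparts. Here `mod` and `params` are placeholders for an existing Relay module and its parameters.

import tvm
from tvm import relay

# opt_level=4 enables the FastMath pass added by this commit; `mod` and
# `params` are assumed to exist already and are not defined in this diff.
with relay.build_config(opt_level=4):
    graph, lib, params = relay.build(mod, target="llvm", params=params)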
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
@@ -163,6 +163,13 @@ TVM_DLL Pass PartialEval();
*/
TVM_DLL Pass SimplifyInference();

/*!
* \brief Replaces non-linear activation functions with their fast but approximate counterparts.
*
* \return The Pass.
*/
TVM_DLL Pass FastMath();

/*!
* \brief Infer the type of an expression.
*
14 changes: 13 additions & 1 deletion python/tvm/relay/transform.py
@@ -56,7 +56,8 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
"CombineParallelDense": 4,
"FastMath": 4
}
fallback_device : int, str, or tvmContext, optional
@@ -179,6 +180,17 @@ def SimplifyInference():
    return _transform.SimplifyInference()


def FastMath():
    """Converts the expensive non-linear functions to their fast but approximate counterparts.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that replaces expensive non-linear functions with
        their fast approximate counterparts.
    """
    return _transform.FastMath()


def CanonicalizeOps():
    """Canonicalize special operators to basic operators.
    This can simplify subsequent analysis, e.g. expanding bias_add to
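The pass can also be applied on its own, outside the default build pipeline. A hypothetical sketch (not part of this commit; the Relay module constructor is relay.Module.from_expr at this point in history and tvm.IRModule.from_expr in later versions):

from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
f = relay.Function([x], relay.exp(relay.tanh(x)))
mod = relay.Module.from_expr(f)

# Run type inference, then the new pass; exp is expected to become fastexp
# and tanh to become fasttanh in the printed module.
mod = relay.transform.InferType()(mod)
mod = relay.transform.FastMath()(mod)
print(mod)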
4 changes: 4 additions & 0 deletions src/relay/backend/build_module.cc
@@ -307,6 +307,10 @@ class RelayBuildModule : public runtime::ModuleNode {
}
pass_seqs.push_back(transform::FoldConstant());

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

// Create a sequential pass and perform optimizations.
transform::Pass seq = transform::Sequential(pass_seqs);
if (targets.size() == 1) {
22 changes: 22 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -96,6 +96,17 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));


RELAY_REGISTER_UNARY_OP("fastexp")
.describe(R"code(Returns the fast approximation of exp of the input array, computed element-wise.
.. math::
   \exp(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fastexp));


RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
@@ -251,6 +262,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));


RELAY_REGISTER_UNARY_OP("fasttanh")
.describe(R"code(Returns the fast approximation of tanh of the input array, computed element-wise.
.. math::
Y = sinh(X) / cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fasttanh));


RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.
10 changes: 10 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
return CallNode::make(op, {e});
}

inline Expr FastExp(Expr e) {
static const Op& op = Op::Get("fastexp");
return CallNode::make(op, {e});
}

inline Expr FastTanh(Expr e) {
static const Op& op = Op::Get("fasttanh");
return CallNode::make(op, {e});
}

inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
87 changes: 84 additions & 3 deletions topi/include/topi/elemwise.h
@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);

/*
* \brief Fast_tanh_float implementation from Eigen
@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
*
* \return A Tensor whose op member is tanh
*/
inline Tensor tanh(const Tensor& x,
std::string name = "T_tanh",
std::string tag = kElementWise) {
inline Tensor fasttanh(const Tensor& x,
std::string name = "T_fasttanh",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
@@ -377,5 +378,85 @@ inline Tensor full_like(const Tensor& x,
}, name, tag);
}

/*!
* \brief Fast exponential function implementation
*
* \param _x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the exponential operation
*
* \note The function computes:
*       e^x = 2^(x * log2(e))
*       The exponent x * log2(e) is split into an integer part and a remainder:
*       n = floor(x * log2(e) + 1/2)
*       f = x - n * ln(2)
*       exp(x) = 2^n * exp(f)
*       The factor exp(f) is approximated by the polynomial
*       y = 1 + f + f^2 * (p5 + p4*f + p3*f^2 + p2*f^3 + p1*f^4 + p0*f^5)
*/
inline Tensor fast_exp_float32(const Tensor& _x,
std::string name,
std::string tag) {
auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
make_const(DataType::Float(32), 1.3981999507E-3f),
make_const(DataType::Float(32), 8.3334519073E-3f),
make_const(DataType::Float(32), 4.1665795894E-2f),
make_const(DataType::Float(32), 1.6666665459E-1f),
make_const(DataType::Float(32), 5.0000001201E-1f)};
auto one = make_const(DataType::Float(32), 1.0f);
auto one_half = make_const(DataType::Float(32), 0.5f);
auto b = make_const(DataType::Float(32), 127.0f);

return compute(_x->shape,
[&](const Array<Var>& i) {
// clamp x
auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
// integer part
auto n = ::tvm::floor(x * log2e + one_half);
// fractional part
auto f = x - n * ln2;
auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f
          + p[5]) * f * f + f + one;
// Return 2^n * exp(f).
auto ef = tvm::reinterpret(DataType::Float(32),
::tvm::cast(DataType::Int(32), n + b) << 23);
return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
},
name, tag);
}


/*!
* \brief Fast exponential function implementation
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the exponential operation
*
*/
inline Tensor fastexp(const Tensor& x,
std::string name = "T_fastexp",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
auto ret = fast_exp_float32(x, name, tag);
return ret;
} else {
return compute(x->shape, [&](const Array<Var>& i) {
return ::tvm::exp(x(i));
}, name, tag);
}
}

} // namespace topi
#endif // TOPI_ELEMWISE_H_
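For intuition, the range reduction documented above can be restated in NumPy. This is an illustrative sketch only, not the TVM kernel: fast_exp_f32 is a hypothetical helper name, the constants are copied from fast_exp_float32 above, and the final max-with-input step is omitted.

import numpy as np

def fast_exp_f32(x):
    # Clamp to the same range as fast_exp_float32, then split x*log2(e)
    # into an integer part n and a remainder f.
    x = np.clip(x.astype(np.float32), np.float32(-88.3762626647949),
                np.float32(88.3762626647950))
    log2e = np.float32(1.44269504088896341)
    ln2 = np.float32(0.6931471805599453)
    p = np.array([1.9875691500e-4, 1.3981999507e-3, 8.3334519073e-3,
                  4.1665795894e-2, 1.6666665459e-1, 5.0000001201e-1],
                 dtype=np.float32)
    n = np.floor(x * log2e + np.float32(0.5))
    f = x - n * ln2
    # Polynomial approximation of exp(f).
    y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f
         + p[5]) * f * f + f + np.float32(1.0)
    # 2^n via exponent-bit manipulation, mirroring the reinterpret / "<< 23" trick.
    two_n = ((n + np.float32(127.0)).astype(np.int32) << 23).view(np.float32)
    return two_n * y

a = np.arange(-10, 10, 0.01, dtype=np.float32)
print(np.max(np.abs(fast_exp_f32(a) - np.exp(a)) / np.exp(a)))  # small relative error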
32 changes: 32 additions & 0 deletions topi/python/topi/math.py
@@ -449,3 +449,35 @@ def reinterpret(x, dtype):
        The result.
    """
    return cpp.reinterpret(x, dtype)


def fastexp(x):
    """Take the exponential of input x using the fast approximate implementation.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return cpp.fastexp(x, x.dtype, tag.ELEMWISE)


def fasttanh(x):
    """Take the hyperbolic tangent of input x using the fast approximate implementation.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return cpp.fasttanh(x, x.dtype, tag.ELEMWISE)
10 changes: 9 additions & 1 deletion topi/src/topi.cc
@@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
*rv = exp(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fastexp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fastexp(args[0]);
});

TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = erf(args[0]);
@@ -183,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fasttanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fasttanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
38 changes: 38 additions & 0 deletions topi/tests/python/test_topi_math.py
@@ -185,7 +185,45 @@ def verify(from_dtype, to_dtype, low=-100, high=100):
    verify("bool", "int32")


def test_fastmath():
    def test_apply(
            func,
            name,
            f_numpy,
            low,
            high,
            step,
            dtype=tvm.float32
    ):
        a_np = np.arange(low, high, step).astype(dtype)
        b_np = f_numpy(a_np)
        A = tvm.placeholder(a_np.shape, dtype=dtype, name="A")
        B = func(A)
        assert tuple(B.shape) == tuple(A.shape)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            with tvm.target.create(device):
                s = topi.generic.schedule_injective(B)
            func = tvm.build(s, [A, B], device, name=name)
            a = tvm.nd.array(a_np, ctx)
            b = tvm.nd.array(np.zeros_like(b_np), ctx)
            func(a, b)
            tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

        check_device('llvm')
        check_device('llvm -device=arm-cpu')

    test_apply(topi.fastexp, "fastexp", np.exp,
               low=-88, high=88,
               step=0.01)
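A companion check for fasttanh is a natural follow-up; it is not part of this diff, but inside test_fastmath it could look like:

    # Hypothetical addition, not in this commit: check fasttanh against np.tanh.
    test_apply(topi.fasttanh, "fasttanh", np.tanh,
               low=-10, high=10,
               step=0.01)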

if __name__ == "__main__":
    test_util()
    test_ewise()
    test_cast()
    test_fastmath()
