Skip to content

Commit

Permalink
[Relay][FastMath] Relay pass to use fast exp/tanh (apache#4873)
Browse files Browse the repository at this point in the history
* [Relay][FastMath] Relay pass to use fast exp/tanh

* Adding required_pass to the tests.

* FastMath test changes.
  • Loading branch information
anijain2305 authored and zhiics committed Apr 17, 2020
1 parent 30e3959 commit 8027b6a
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 6 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ TVM_DLL Pass PartialEval();
*/
TVM_DLL Pass SimplifyInference();

/*!
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*
* \return The Pass.
*/
TVM_DLL Pass FastMath();

/*!
* \brief Infer the type of an expression.
*
Expand Down
16 changes: 14 additions & 2 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
"CombineParallelDense": 4,
"FastMath": 4
}
fallback_device : int, str, or tvmContext, optional
Expand Down Expand Up @@ -175,11 +176,22 @@ def SimplifyInference():
Returns
-------
ret: tvm.relay.Pass
The registered to perform operator simplification.
The registered pass to perform operator simplification.
"""
return _transform.SimplifyInference()


def FastMath():
""" Converts the expensive non linear functions to their fast but approximate counterparts.
Returns
-------
ret: tvm.relay.Pass
The registered pass to perform fast math operations.
"""
return _transform.FastMath()


def CanonicalizeOps():
"""Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to
Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

// Create a sequential pass and perform optimizations.
Expand Down
22 changes: 22 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));


RELAY_REGISTER_UNARY_OP("fast_exp")
.describe(R"code(Returns the fast_exp input array, computed element-wise.
.. math::
\fast_exp(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));


RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
Expand Down Expand Up @@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));


RELAY_REGISTER_UNARY_OP("fast_tanh")
.describe(R"code(Returns the fast_tanh of input array, computed element-wise.
.. math::
Y = sinh(X) / cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));


RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.
Expand Down
79 changes: 79 additions & 0 deletions src/relay/pass/fast_math.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file fast_math.cc
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op.h>
#include "pattern_util.h"

namespace tvm {
namespace relay {

class FastMathMutator : public ExprMutator {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {}

Expr VisitExpr_(const CallNode* n) {
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op == exp_op_) {
return FastExp(new_n.as<CallNode>()->args[0]);
} else if (n->op == tanh_op_) {
return FastTanh(new_n.as<CallNode>()->args[0]);
}
return new_n;
}

private:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const Op& exp_op_;
const Op& tanh_op_;
};

Expr FastMath(const Expr& e) {
return FastMathMutator().Mutate(e);
}

namespace transform {

Pass FastMath() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{tir::StringImmNode::make("InferType")});
}

TVM_REGISTER_GLOBAL("relay._transform.FastMath")
.set_body_typed(FastMath);

} // namespace transform

} // namespace relay
} // namespace tvm
10 changes: 10 additions & 0 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
return CallNode::make(op, {e});
}

inline Expr FastExp(Expr e) {
static const Op& op = Op::Get("fast_exp");
return CallNode::make(op, {e});
}

inline Expr FastTanh(Expr e) {
static const Op& op = Op::Get("fast_tanh");
return CallNode::make(op, {e});
}

inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
Expand Down
52 changes: 52 additions & 0 deletions tests/python/relay/test_pass_fast_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.ir import IRModule
from tvm import relay
from tvm.relay.transform import FastMath

def test_exp():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.exp(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

fast_mod = FastMath()(mod)
assert "fast_exp" in fast_mod.astext()

# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_exp" in fast_mod[0].astext()

def test_tanh():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.tanh(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

fast_mod = FastMath()(mod)
assert "fast_tanh" in fast_mod.astext()

# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_tanh" in fast_mod[0].astext()

if __name__ == "__main__":
test_exp()
test_tanh()
7 changes: 4 additions & 3 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);

/*
* \brief Fast_tanh_float implementation from Eigen
Expand Down Expand Up @@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
*
* \return A Tensor whose op member is tanh
*/
inline Tensor tanh(const Tensor& x,
std::string name = "T_tanh",
std::string tag = kElementWise) {
inline Tensor fast_tanh(const Tensor& x,
std::string name = "T_fast_tanh",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
Expand Down
16 changes: 16 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,19 @@ def fast_exp(x):
The result.
"""
return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)


def fast_tanh(x):
"""Take tanhonential of input x using fast_tanh implementation
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
5 changes: 4 additions & 1 deletion topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
Expand Down

0 comments on commit 8027b6a

Please sign in to comment.