Skip to content

Commit

Permalink
Fast exponent
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgl-github committed Jan 31, 2020
1 parent 1b8522e commit 1831312
Showing 1 changed file with 60 additions and 1 deletion.
61 changes: 60 additions & 1 deletion topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ using namespace tvm::te;
}, name, tag); \
}

TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
Expand Down Expand Up @@ -360,5 +359,65 @@ inline Tensor full_like(const Tensor& x,
}, name, tag);
}

/*
* \brief Fast exponential function implementation
* log2(e^x) = x * log2(e) * log2(2) =>
* log2(e^x) = log2(2^(x*log2(e))) =>
* e^x = 2^(x*log2(e))
* Splitting power x*log2(e) into integer and fractional parts:
* e^(int_part+frac_part) = e^int_part * e^frac_part
* n = floor(x*log2(e) + 1/2)
* f = x - n * ln(2)
* exp(x) = 2^n * exp(y)
* Approximation for fractional part:
* y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2))
*/
inline Tensor fast_exp(const Tensor& _x,
std::string name,
std::string tag) {
auto one = make_const(DataType::Float(32), 1.0f);
auto one_half = make_const(DataType::Float(32), 0.5f);
auto b = make_const(DataType::Float(32), 127.0f);
auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
make_const(DataType::Float(32), 1.3981999507E-3f),
make_const(DataType::Float(32), 8.3334519073E-3f),
make_const(DataType::Float(32), 4.1665795894E-2f),
make_const(DataType::Float(32), 1.6666665459E-1f),
make_const(DataType::Float(32), 5.0000001201E-1f)};

return compute(_x->shape,
[&](const Array<Var>& i) {
// clamp x
auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
// integer part
auto n = ::tvm::floor(x * log2e + one_half);
// fractional part
auto f = x - n * ln2;
auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one;
// Return 2^m * exp(r).
auto ef = tvm::reinterpret(DataType::Float(32),
::tvm::cast(DataType::Int(32), n + b) << 23);
return ::tvm::max(ef * y, _x(i));
},
name, tag);
}


inline Tensor exp(const Tensor& x,
std::string name = "T_exp",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
return fast_exp(x, name, tag);
} else {
return compute(x->shape, [&](const Array<Var>& i) {
return ::tvm::exp(x(i));
}, name, tag);
}
}

} // namespace topi
#endif // TOPI_ELEMWISE_H_

0 comments on commit 1831312

Please sign in to comment.