Skip to content

Commit

Permalink
Fixed point multiplication improvements for AArch64 (apache#5980)
Browse files Browse the repository at this point in the history
* Fixed point multiplication improvements for AArch64

Change-Id: Ib3c10348d4c0eac11fa92b39cc6e792560e9eba4

* Fix python linting errors

Change-Id: I4cf5ac18aa24b39374b83805dcc8e1663e173909

* Fix doxygen errors

Change-Id: Ie3c861f8ead3f1ea5b30d5e9d7d94e222299d407

* Fix arm_cpu injective tests

Change-Id: I6ad9da61b61e6bd737627f26fba59767418c07cd

* Fix python linting errors - 2

Change-Id: Ic864a235aa5da5786393cbf6146dd815c121df5e

* Fix arm_cpu injective tests - 2

Change-Id: If9ca1cc3d947b1656c836c7f88de90470d92f979

* Redesign: introduce a qmuls (q-multiply and shift) general intrinsic

Change-Id: I1966fef9aee32eab50e4b984bbe81018488c8c02

* Fix python linting errors - 3

Change-Id: Ib87a19a8ee2d532954a7db1eb5793666e7aef366

* Addressing review comments

Change-Id: Ie82e75204e5a421d17660f381f3e31fc325cd26c

* Fixing test failures

Change-Id: I74cc675764cf8d260fe68a41e770b1ec7e84729a

* Renaming qmuls to q_multiply_shift

Change-Id: I5a8ed60ba855208040304fcdf6e1ea28061f06ad
  • Loading branch information
Giuseppe Rossini authored and Trevor Morris committed Sep 2, 2020
1 parent 25ec71b commit 3075163
Show file tree
Hide file tree
Showing 23 changed files with 382 additions and 52 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
}
};

/*! \brief Attributes for FixedPointMultiply operator */
struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> {
int32_t multiplier;
int32_t shift;

TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") {
TVM_ATTR_FIELD(multiplier)
.describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)");
TVM_ATTR_FIELD(shift).describe(
"Shift of a fixed floating point number described as multiplier*2^(shift)");
}
};

/*! \brief Attributes for LayoutTransform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ TVM_DLL const Op& shift_right();
*/
TVM_DLL const Op& large_uint_imm();

/*!
* \brief Execute a multiplication between two Q-numbers x and y
* followed by a right shift s
* The default rounding rule is to the nearest value, rounding half up
* (i.e., round(x.1) = x and round (x.5) = x+1)
*/
TVM_DLL const Op& q_multiply_shift();

/*!
* \brief See pesudo code
*
Expand Down
21 changes: 21 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,27 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
*/
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

/*!
* \brief Execute a multiplication between two Q-numbers x and y
* followed by a right shift s. The mathematical expression is:
*
* out = round(x*y*2^-s)
*
* Please note that the two Q-numbers x and y are supposed to have
* the same number of fractional bits q.
*
* More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
*
* The rounding rule is to the nearest value, rounding half up
* (i.e., round(x.1) = x and round (x.5) = x+1)
* \param x first Q-number
* \param y second Q-number
* \param q number of fractional bits in x and y. Needs to be > 0
* \param s integer right shift
* \return The constructed expression.
*/
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ def clip_compute(attrs, inputs, output_type):

register_injective_schedule("clip")

# fixed point multiply
@register_compute("fixed_point_multiply")
def fixed_point_multiply_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]

register_injective_schedule("fixed_point_multiply")

# full
@script
def _full_shape_func(shape):
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,27 @@ def clip(a, a_min, a_max):
"""
return _make.clip(a, a_min, a_max)

def fixed_point_multiply(data, multiplier, shift):
"""Fixed point multiplication between data and a fixed point
constant expressed as multiplier * 2^(-shift), where multiplier
is a Q-number with 31 fractional bits
Parameters
----------
data : relay.Expr
The input tensor.
multiplier : int
The integer multiplier of the fixed point constant.
a_max : float
The integer shift of the fixed point constant.
Returns
-------
result : relay.Expr
The output of the fixed point multiplication
"""
return _make.fixed_point_multiply(data, multiplier, shift)


def concatenate(data, axis):
"""Concatenate the input tensors along the given axis.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .op import isnan, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from . import ir_builder
from . import transform
Expand Down
28 changes: 28 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,34 @@ def popcount(x):
"""
return call_intrin(x.dtype, "tir.popcount", x)

def q_multiply_shift(x, y, q, s):
"""Execute a multiplication between two Q-numbers x and y
followed by a right shift s. The mathematical expression is:
out = round(x*y*2^-s)
More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
The rounding rule is to the nearest value, rounding half up
(i.e., round(x.1) = x and round (x.5) = x+1)
Parameters
----------
x : PrimExpr
First Q-number
y : PrimExpr
Second Q-number
q : PrimExpr
Number of fractional bits in x and y. Needs to be > 0
s : PrimExpr
Integer shift
Returns
-------
y : PrimExpr
The result.
"""
return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s)

def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
Expand Down
23 changes: 23 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,29 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma
.set_attrs_type<ClipAttrs>()
.set_support_level(3);

// relay.fixed_point_multiply
TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs);

TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply")
.set_body_typed([](Expr a, int32_t multiplier, int32_t shift) {
auto attrs = make_object<FixedPointMultiplyAttrs>();
attrs->multiplier = multiplier;
attrs->shift = shift;
static const Op& op = Op::Get("fixed_point_multiply");
return Call(op, {a}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("fixed_point_multiply")
.describe(R"code(fixed point multiplication)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attrs_type<FixedPointMultiplyAttrs>()
.set_support_level(10);

RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
Expand Down
12 changes: 11 additions & 1 deletion src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,19 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);

const bool is_upward_rounding = (param->rounding == "UPWARD");

// When using upward rounding (i.e., x.5 rounded to x+1), leverage
// the FixedPointMultiply operator
scaled_int32_t =
FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
(is_upward_rounding
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
}

} else {
// This is per-channel (per=axis) quantization.
std::vector<double> double_multipliers;
Expand Down
43 changes: 10 additions & 33 deletions src/relay/qnn/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,6 @@ namespace tvm {
namespace relay {
namespace qnn {

/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
Expand All @@ -75,8 +56,8 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
return std::make_pair(significand, exponent);
}

Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
const std::string& rounding) {
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64);
Expand Down Expand Up @@ -109,19 +90,15 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));

Expr round_scalar;
if (rounding == "UPWARD") {
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
} else if (rounding == "TONEAREST") {
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);

auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);

auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);

// Add the rounding scalar.
tensor = Add(tensor, round_scalar);

Expand Down
32 changes: 26 additions & 6 deletions src/relay/qnn/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,27 @@ static inline int32_t GetQmax(const DataType& dtype) {
}
}

/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier);

Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
Expand All @@ -94,13 +115,12 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {

/*
* \brief Fixed point multiplication between integer tensor with floating point
scalar.
* scalar. This implementation rounds to the nearest value when it is midway
* between two representable values.
* \param tensor The quantized input tensor of dtype int64.
* \param multiplier The scalar multiplier.
* \param input_shape Shape of the input tensor.
* \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
is midway between" "two representable values.
* \return The sequence of Relay ops for fixed point multiplication.
* \return The sequence of Relay ops for fixed point multiplication with TONEARES rounding.
* \note Original compuation is scale_fp32 * quantized_tensor. To convert into
* integer computation, the multiplication with fp32 scalar can be
Expand All @@ -114,8 +134,8 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
* 2) Round the result.
* 3) Right shift the result
*/
Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
const std::string& rounding);
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape);

/*
* \brief Fixed point multiplication between integer tensor with floating point
Expand Down
20 changes: 17 additions & 3 deletions src/relay/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,14 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
}

return Cast(data, dtype);
}
}
Expand Down Expand Up @@ -164,8 +171,15 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
return QRealizeIntExpr(data, dom_scale, n->dtype);
} else {
data = Cast(data, DataType::Int(64));
data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape, cfg->rounding);
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) =
qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape);
}
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
}
Expand Down
8 changes: 8 additions & 0 deletions src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,14 @@ inline Expr Round(Expr x) {

inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); }

inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) {
static const Op& op = Op::Get("fixed_point_multiply");
auto attrs = make_object<FixedPointMultiplyAttrs>();
attrs->multiplier = multiplier;
attrs->shift = shift;
return Call(op, {x}, Attrs(attrs), {});
}

inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return Call(op, {lhs, rhs}, Attrs(), {});
Expand Down
Loading

0 comments on commit 3075163

Please sign in to comment.