Skip to content

Commit

Permalink
Redesign
Browse files Browse the repository at this point in the history
Change-Id: I0f7dac8f5bf0efb7ec9bc8eb95475b0f2412fbd8
  • Loading branch information
Giuseppe Rossini committed Jul 6, 2020
1 parent 10c982b commit bb787cd
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 55 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
* \param s integer shift
* \return The constructed expression.
*/
TVM_DLL PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s);
TVM_DLL PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr y, int32_t n);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
Expand Down
11 changes: 4 additions & 7 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from tvm.te.hybrid import script
import topi

from . import strategy
from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_broadcast_schedule, register_injective_schedule, register_strategy
from .op import register_pattern, OpPattern


Expand Down Expand Up @@ -131,13 +132,9 @@ def clip_compute(attrs, inputs, output_type):

register_injective_schedule("clip")

# fixed point multiply
@register_compute("fixed_point_multiply")
def fixed_point_multiply_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]

register_injective_schedule("fixed_point_multiply")
# fixed point multiply
register_strategy("fixed_point_multiply", strategy.fixed_point_multiply_strategy)

# full
@script
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def schedule_concatenate_arm_cpu(_, outs, target):
"""schedule concatenate for arm cpu"""
with target:
return topi.arm_cpu.schedule_concatenate(outs)

@conv2d_strategy.register(["arm_cpu", "micro_dev"])
def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d arm cpu strategy"""
Expand Down Expand Up @@ -332,3 +332,13 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_dense),
name="bitserial_dense.arm_cpu")
return strategy

@fixed_point_multiply_strategy.register("arm_cpu")
def schedule_fixed_point_multiply(attrs, inputs, out_type, target):
"""bitserial_dense arm cpu strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_fixed_point_multiply(topi.arm_cpu.fixed_point_multiply),
wrap_topi_schedule(topi.arm_cpu.schedule_fixed_point_multiply),
name="fixed_point_multiply.arm_cpu")
return strategy
17 changes: 17 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,23 @@ def softmax_strategy(attrs, inputs, out_type, target):
name="softmax.generic")
return strategy

# fixed_point_multiply
def wrap_compute_fixed_point_multiply(topi_compute):
"""Wrap softmax topi compute"""
def _compute_fixed_point_multiply(attrs, inputs, out_type):
return [topi_compute(inputs[0], attrs.multiplier, attrs.shift)]
return _compute_fixed_point_multiply

@override_native_generic_func("fixed_point_multiply_strategy")
def fixed_point_multiply_strategy(attrs, inputs, out_type, target):
"""fixed_point_multiply_strategy generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_fixed_point_multiply(topi.math.fixed_point_multiply),
wrap_topi_schedule(topi.generic.schedule_injective),
name="fixed_point_multiply.generic")
return strategy

# log_softmax
@generic_func
def schedule_log_softmax(attrs, outs, target):
Expand Down
44 changes: 12 additions & 32 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,42 +123,22 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.fixed_point_multiply")
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);

PrimExpr tensor = call->args[0];
PrimExpr fixed_point_multiplier = call->args[1];
PrimExpr shift = call->args[2];
PrimExpr x = call->args[0];
PrimExpr y = call->args[1];
PrimExpr n = call->args[2];

// Only int32 types are supported (any number of lanes is allowed)
CHECK(tensor.dtype().code() == DLDataTypeCode::kDLInt && tensor.dtype().bits() == 32);
CHECK(fixed_point_multiplier.dtype().code() == DLDataTypeCode::kDLInt &&
fixed_point_multiplier.dtype().bits() == 32);
CHECK(shift.dtype().code() == DLDataTypeCode::kDLInt && shift.dtype().bits() == 32);
CHECK(x.dtype().code() == DLDataTypeCode::kDLInt && x.dtype().bits() == 32);
CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, tensor.dtype().lanes());
DataType lp_dtype = DataType::Int(32, tensor.dtype().lanes());
DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
PrimExpr K = (make_const(hp_dtype, 1) << (n - 1));

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(shift.dtype(), 0);
PrimExpr left_shift = tir::Select((shift > zero), shift, zero);
PrimExpr right_shift = tir::Select(shift > zero, zero, -shift);

// 2) Multiply the integer multiplier
tensor = tir::Select(left_shift != zero, tensor << cast(hp_dtype, left_shift),
cast(hp_dtype, tensor));

// 3) Perform the multiplication in higher precision.
tensor = tensor * fixed_point_multiplier;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + 31;
PrimExpr pos_rounding_value = (make_const(hp_dtype, 1) << (total_right_shift - 1));

tensor = tensor + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
tensor = tensor >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, tensor);
x = cast(hp_dtype, x) * cast(hp_dtype, y);
x = x + K;
x = x >> n;
*rv = cast(lp_dtype, x);
});

} // namespace intrin
Expand Down
4 changes: 2 additions & 2 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
}

// fixed_point_multiply
PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s) {
return tir::Call(x.dtype(), tir::builtin::fixed_point_multiply(), {x, m, s});
PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr y, int32_t n) {
return tir::Call(x.dtype(), tir::builtin::fixed_point_multiply(), {x, y, make_const(DataType::UInt(32), n)});
}

// The public function with a quick checking path.
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_op_qnn_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def test_downscale():
# Try positive values
# 8 corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
print(golden_data)
golden_output = np.repeat([0, 1, 2], [8, 16, 8])
print(golden_output)
verify(mod, (golden_data, golden_output))

# Try negative values
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/arm_cpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
from .bitserial_dense import *
from .injective import *
from . import cortex_m7
from .math import fixed_point_multiply
31 changes: 31 additions & 0 deletions topi/python/topi/arm_cpu/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,37 @@ def schedule_injective(outs):
schedule_injective_from_existing(s, x)
return s

def schedule_fixed_point_multiply(outs):
"""ARM CPU schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of injective in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
x = outs[0]
ins = x.op.input_tensors
dtype = ins[0].dtype if len(ins) > 0 else x.dtype
max_vlen = 4 if dtype == 'int32' else 8

if list(s[x].op.axis):
# do not vectorize for broadcast
(io, ii) = s[x].split(list(s[x].op.axis)[-1], max_vlen)
s[x].vectorize(ii)
tvm.te.schedule.AutoInlineInjective(s)

if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s

def schedule_concatenate(outs):
"""Schedule for concatenate op.
Expand Down
29 changes: 22 additions & 7 deletions topi/python/topi/arm_cpu/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm.contrib import util, clang
from tvm.ir import Array, Op

def gemv_quantized_impl(M, N, data_type='uint8'):
""" Assembly implementation of a blocked gemv. Given
Expand Down Expand Up @@ -452,18 +453,22 @@ def _instr(index):
C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
default_buffer_params=buffer_params)

def _fixed_point_multiply_arm(op):
def fixed_point_multiply_arm_rule(op):
"""
Implementation of fixed point multiplication through arm
intrinsics sqrdmulh and srshl
"""
x = op.args[0]
multiplier = op.args[1]
shift = op.args[2]

# Don't use this intrinsic if we don't have a int32x4 vector
# Don't use AArch64 intrinsics if we don't have a int32x4 vector
if x.dtype != "int32x4":
return op
left_shift = tvm.tir.if_then_else(shift > 0, shift, 0)
right_shift = tvm.tir.if_then_else(shift > 0, 0, -shift)
x = tvm.tir.if_then_else(left_shift > 0, x << left_shift, x)
mulq = tvm.tir.fixed_point_multiply(x, multiplier, 31)
return mulq >> right_shift

# Case 1, shift is negative
sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype,
Expand All @@ -477,7 +482,7 @@ def _fixed_point_multiply_arm(op):
out_1 = tvm.tir.call_llvm_intrin(op.dtype,
'llvm.aarch64.neon.srshl',
tvm.tir.const(2, 'uint32'),
sqrdmulh,
fixed_up_x,
shift)

# Case 2, shift is positive
Expand All @@ -491,6 +496,16 @@ def _fixed_point_multiply_arm(op):
# Select depending on the shift
return tvm.tir.Select(shift < 0, out_1, out_2)


def fixed_point_multiply_arm(x, m, s):
"""customized log intrinsic function"""
return tvm.tir.call_intrin(x.dtype, "tir.fixed_point_multiply_arm", x, m, s)


tvm.target.intrin.register_intrin_rule("llvm.aarch64",
"fixed_point_multiply",
_fixed_point_multiply_arm, override=True)
"fixed_point_multiply_arm",
fixed_point_multiply_arm_rule, override=False)

tvm.ir.register_op_attr("tir.fixed_point_multiply_arm", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
tvm.ir.register_op_attr("tir.fixed_point_multiply_arm", "TVectorizable", True)

14 changes: 9 additions & 5 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,6 @@ def _compute(*indices):
return tvm.te.max(tvm.te.min(value, const_max), const_min)
return te.compute(x.shape, _compute)

@tvm.te.tag_scope(tag=tag.ELEMWISE)
def fixed_point_multiply(x, multiplier, shift):
"""
Expand All @@ -628,15 +627,20 @@ def fixed_point_multiply(x, multiplier, shift):
y : tvm.te.Tensor
The result.
"""
left_shift = shift if shift > 0 else 0
right_shift = 0 if shift > 0 else -shift

def _compute(*indices):
value = x(*indices)
m = tvm.tir.const(multiplier, x.dtype)
s = tvm.tir.const(shift, x.dtype)
return tvm.tir.fixed_point_multiply(value, m, s)
val = x(*indices)
val = tvm.tir.if_then_else(left_shift > 0, val << left_shift, val)
mulq = tvm.tir.fixed_point_multiply(val, multiplier, 31)
nudge = (1 << (right_shift-1))
return (mulq+nudge) >> right_shift

assert x.dtype == "int32", "input tensor type needs to be int32"
return te.compute(x.shape, _compute)


def cast(x, dtype):
"""Cast input to specified data type.
Expand Down

0 comments on commit bb787cd

Please sign in to comment.