# Provenance: contents of vta/python/vta/top/relay_bitpack.py as introduced by
# commit d155aec10a801352f8e83367c6d43511188dee34
# ("Add relay_bitpack.py (#36)", Jared Roesch, Fri, 7 Dec 2018 00:26:00 -0800).
"""Bit packing operators"""
from __future__ import absolute_import as _abs

import tvm
from topi import util

from tvm.relay.op.op import register_compute, register_schedule
from tvm.relay.op.op import register_pattern, OpPattern
from tvm.relay.op.op import schedule_injective


def bitpack(data, bits, pack_type="int8", name="bitpack"):
    """Packs lowest dimension into format needed by VTA

    Groups of consecutive elements along the last axis of ``data`` are
    merged into a single ``pack_type`` word.  Only the low ``bits`` bits of
    each element are kept; the element at position ``k`` inside a group is
    shifted left by ``k * bits``, so the first element of a group ends up in
    the least-significant bits of the word.

    Parameters
    ----------
    data : Tensor
        Input tensor.  It is called with an index tuple to read an element
        and its ``shape`` gives the extents; the last extent must be a
        compile-time constant.
    bits : int
        Number of low-order bits taken from every element.  Must evenly
        divide the bit width of ``pack_type``.
    pack_type : str
        Type of the packed words: ``"int8"``, ``"int16"`` or ``"int32"``.
    name : str
        Name passed through to ``tvm.compute``.

    Returns
    -------
    packed : Tensor
        The packed tensor.  Same shape as ``data`` except that the last
        axis is divided by ``width(pack_type) // bits``.

    Raises
    ------
    RuntimeError
        If ``pack_type`` is not one of the supported types.
    AssertionError
        If ``bits`` does not divide the word width, or the last axis of
        ``data`` is not a multiple of the number of lanes per word.
    """
    widths = {"int8": 8, "int16": 16, "int32": 32}
    if pack_type not in widths:
        raise RuntimeError("Unknown pack type %s" % pack_type)
    data_width = widths[pack_type]

    # The checks below are raised explicitly instead of using ``assert`` so
    # that they are not stripped under ``python -O``; AssertionError is kept
    # so the exception type seen by callers does not change.
    if data_width % bits != 0:
        raise AssertionError(
            "bits (%s) must evenly divide the %d-bit width of %s"
            % (bits, data_width, pack_type))
    lanes = data_width // bits

    shape_vec = list(data.shape)
    # Data must be in multiples of the data_width
    if util.get_const_int(shape_vec[-1]) % lanes != 0:
        raise AssertionError("Not a multiple of word size")
    shape_vec[-1] = shape_vec[-1] // lanes
    oshape = tuple(shape_vec)

    def _bitpack(*indices):
        """Build the expression for one packed output word."""
        # Keeps only the low ``bits`` bits of an element.
        mask = tvm.const((1 << bits) - 1, pack_type)
        ret = None
        for k in range(lanes):
            # Output index i along the last axis reads inputs i*lanes + k.
            idx = list(indices)
            idx[-1] = idx[-1] * lanes + k
            elem = data(*idx).astype(pack_type)
            if k == 0:
                # Lane 0 sits in the lowest bits, so no shift is needed.
                ret = elem & mask
            else:
                val = (elem & mask) << tvm.const(k * bits, pack_type)
                ret = ret | val
        return ret

    return tvm.compute(
        oshape, _bitpack, name=name, tag='bitpack')


@register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs, output_type, target):
    """Compute function registered for the "bitpack" op.

    Parameters
    ----------
    attrs : object
        Operator attributes; only ``attrs.lanes`` (number of elements
        packed into one 8-bit word) is read.
    inputs : list
        Input tensors; only ``inputs[0]`` is used and its dtype must be
        ``"int8"``.
    output_type : object
        Not used by this function.
    target : object
        Not used by this function.

    Returns
    -------
    packed : Tensor
        Result of ``bitpack(inputs[0], 8 // attrs.lanes, "int8")``.

    Raises
    ------
    AssertionError
        If the input dtype is not ``"int8"`` or ``attrs.lanes`` does not
        evenly divide 8.
    """
    width = 8
    lanes = attrs.lanes
    dtype = inputs[0].dtype

    # Explicit raises (rather than ``assert``) so the checks survive
    # ``python -O``; the exception type is unchanged.
    if dtype != "int8":
        raise AssertionError(
            "bitpack compute only supports int8 input, got %s" % dtype)
    if width % lanes != 0:
        raise AssertionError(
            "lanes (%s) must evenly divide the %d-bit word width"
            % (lanes, width))

    bits = width // lanes
    return bitpack(inputs[0], bits, dtype)


register_schedule("bitpack", schedule_injective)
register_pattern("bitpack", OpPattern.INJECTIVE)