Skip to content

Commit

Permalink
Add relay_bitpack.py (apache#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and tmoreau89 committed Mar 22, 2019
1 parent f756586 commit d155aec
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions vta/python/vta/top/relay_bitpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Bit packing operators"""
from __future__ import absolute_import as _abs

import tvm
from topi import util

from tvm.relay.op.op import register_compute, register_schedule
from tvm.relay.op.op import register_pattern, OpPattern
from tvm.relay.op.op import schedule_injective

def bitpack(data, bits, pack_type="int8", name="bitpack"):
    """Pack groups of narrow values along the last axis into wider words.

    Consecutive elements of the last axis of ``data`` are each masked down
    to their ``bits`` low bits and OR-ed together into one ``pack_type``
    word, with the k-th element of a group shifted left by ``k * bits``.

    Parameters
    ----------
    data : Tensor
        Input tensor; only its last axis is packed.
    bits : int
        Number of low bits kept from every input element. Must evenly
        divide the bit width of ``pack_type``.
    pack_type : str
        Element type of the packed words: 'int8', 'int16' or 'int32'.
    name : str
        Name given to the resulting compute op.

    Returns
    -------
    packed : Tensor
        The packed tensor. Its shape equals that of ``data`` except that
        the last axis is divided by ``word width // bits``.

    Raises
    ------
    RuntimeError
        If ``pack_type`` is not one of the supported types.
    AssertionError
        If ``bits`` does not divide the word width, or the extent of the
        last axis (as resolved by ``util.get_const_int``) is not a
        multiple of the number of elements per word.
    """
    dims = list(data.shape)

    # Resolve the bit width of one packed word from its type name.
    for type_name, width in (("int8", 8), ("int16", 16), ("int32", 32)):
        if pack_type == type_name:
            word_bits = width
            break
    else:
        raise RuntimeError("Unknown pack type %s" % pack_type)

    assert word_bits % bits == 0
    per_word = word_bits // bits

    # Data must be in multiples of the data_width
    assert util.get_const_int(dims[-1]) % per_word == 0, "Not a multiple of word size"
    dims[-1] = dims[-1] // per_word
    out_shape = tuple(dims)

    def _pack_lanes(*indices):
        # Mask selecting the `bits` low bits of every source element.
        low_bits = tvm.const((1 << bits) - 1, pack_type)
        word = None
        for lane in range(per_word):
            src = list(indices)
            # Output position i along the last axis reads source
            # positions i * per_word + lane.
            src[-1] = src[-1] * per_word + lane
            field = data(*src).astype(pack_type) & low_bits
            if lane == 0:
                word = field
            else:
                # Later lanes occupy progressively higher bit positions.
                word = word | (field << tvm.const(lane * bits, pack_type))
        return word

    return tvm.compute(out_shape, _pack_lanes, name=name, tag='bitpack')


@register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs, output_type, target):
    """Compute function registered for the "bitpack" op.

    Reads the lane count from ``attrs.lanes``, derives the number of bits
    per lane for an 8-bit word, and delegates to ``bitpack`` on the first
    input. ``output_type`` and ``target`` are accepted but not used here.

    Parameters
    ----------
    attrs : object
        Op attributes; only ``attrs.lanes`` is read.
    inputs : sequence of Tensor
        Op inputs; only ``inputs[0]`` is used and it must have dtype "int8".
    output_type : object
        Unused.
    target : object
        Unused.

    Returns
    -------
    packed : Tensor
        Result of ``bitpack(inputs[0], 8 // attrs.lanes, "int8")``.

    Raises
    ------
    AssertionError
        If the input dtype is not "int8" or ``attrs.lanes`` does not
        divide 8.
    """
    lane_count = attrs.lanes
    tensor = inputs[0]
    in_dtype = tensor.dtype
    # Only 8-bit inputs are handled by this compute rule.
    assert in_dtype == "int8"
    word_bits = 8
    assert word_bits % lane_count == 0
    return bitpack(tensor, word_bits // lane_count, in_dtype)

# Import-time registrations for the "bitpack" op: use the imported
# schedule_injective as its schedule and declare its pattern as
# OpPattern.INJECTIVE.
# NOTE(review): presumably this pattern is what allows the op to be fused
# with neighbouring ops - confirm against the Relay OpPattern docs.
register_schedule("bitpack", schedule_injective)
register_pattern("bitpack", OpPattern.INJECTIVE)

0 comments on commit d155aec

Please sign in to comment.