Vta relay bitpack (apache#34)
* Add bitpacking

* Fix issue in Python wrapper

* Misc fixes

* Fix some bugs in expr.py
jroesch authored and tmoreau89 committed Jan 2, 2019
1 parent 2b51e79 commit 45b4d77
Showing 5 changed files with 209 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -188,6 +188,15 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
}
};

/*! \brief Attributes used by the bitpack operator. */
struct BitPackAttrs : public tvm::AttrsNode<BitPackAttrs> {
  int lanes;

  TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") {
    TVM_ATTR_FIELD(lanes)
        .describe("Number of lanes packed in one element");
  }
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
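Once the op is registered (see transform.cc below), these attrs surface on bitpack call nodes from Python. A minimal hypothetical check, assuming the relay.bitpack wrapper added later in this commit is re-exported at the package level:

from tvm import relay

x = relay.var("x", shape=(4, 8), dtype="int8")
call = relay.bitpack(x, lanes=4)    # wrapper added in transform.py below
assert int(call.attrs.lanes) == 4   # field declared by BitPackAttrs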
137 changes: 137 additions & 0 deletions python/tvm/relay/expr.py
@@ -344,6 +344,143 @@ def realize(self):
        return _expr.TempExprRealize(self)


class ExprFunctor(object):
    """
    An abstract visitor defined over Expr.

    Defines the default dispatch over expressions, and
    implements memoization.
    """
    def __init__(self):
        self.memo_map = {}

    # pylint: disable=no-else-return
    def visit(self, expr):
        """Apply the visitor to an expression."""
        from .op.op import Op
        found = self.memo_map.get(expr)
        if found is not None:
            return found

        if isinstance(expr, Function):
            res = self.visit_function(expr)
        elif isinstance(expr, Call):
            res = self.visit_call(expr)
        elif isinstance(expr, Let):
            res = self.visit_let(expr)
        elif isinstance(expr, Var):
            res = self.visit_var(expr)
        elif isinstance(expr, GlobalVar):
            res = self.visit_global_var(expr)
        elif isinstance(expr, If):
            res = self.visit_if(expr)
        elif isinstance(expr, Tuple):
            res = self.visit_tuple(expr)
        elif isinstance(expr, TupleGetItem):
            res = self.visit_tuple_getitem(expr)
        elif isinstance(expr, Constant):
            res = self.visit_constant(expr)
        elif isinstance(expr, Op):
            res = self.visit_op(expr)
        else:
            raise TypeError("unhandled case: {0}".format(type(expr)))

        self.memo_map[expr] = res
        return res

    def visit_function(self, _):
        raise NotImplementedError()

    def visit_let(self, _):
        raise NotImplementedError()

    def visit_call(self, _):
        raise NotImplementedError()

    def visit_var(self, _):
        raise NotImplementedError()

    def visit_type(self, typ):
        return typ

    def visit_if(self, _):
        raise NotImplementedError()

    def visit_tuple(self, _):
        raise NotImplementedError()

    def visit_tuple_getitem(self, _):
        raise NotImplementedError()

    def visit_constant(self, _):
        raise NotImplementedError()

    def visit_global_var(self, _):
        raise NotImplementedError()

    def visit_op(self, _):
        raise NotImplementedError()


class ExprMutator(ExprFunctor):
    """
    A functional visitor over Expr.

    The default behavior recursively traverses the AST
    and reconstructs it.
    """
    def visit_function(self, fn):
        new_body = self.visit(fn.body)
        return Function(
            list(fn.params),
            new_body,
            fn.ret_type,
            fn.type_params)

    def visit_let(self, let):
        new_var = self.visit(let.var)
        new_val = self.visit(let.value)
        new_body = self.visit(let.body)
        return Let(new_var, new_val, new_body)

    def visit_call(self, call):
        new_fn = self.visit(call.op)
        new_args = [self.visit(arg) for arg in call.args]
        return Call(new_fn, new_args, call.attrs)

    def visit_var(self, rvar):
        return rvar

    def visit_if(self, ite):
        return If(
            self.visit(ite.guard),
            self.visit(ite.true_b),
            self.visit(ite.false_b))

    def visit_tuple(self, tup):
        return Tuple([self.visit(field) for field in tup.fields])

    def visit_tuple_getitem(self, op):
        tuple_value = self.visit(op.tuple_value)
        if not tuple_value.same_as(op.tuple_value):
            return TupleGetItem(tuple_value, op.index)
        return op

    def visit_global_var(self, gvar):
        return gvar

    def visit_constant(self, rconst):
        return rconst

    def visit_op(self, op):
        return op


class TupleWrapper(object):
"""TupleWrapper.
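For context, ExprMutator is meant to be subclassed: override only the cases you care about and let the memoized visit dispatch rebuild the rest. A minimal hypothetical sketch, not part of this commit, that rewrites calls to one operator into calls to another:

from tvm import relay
from tvm.relay.expr import Call, ExprMutator
from tvm.relay.op.op import Op

class RewriteOp(ExprMutator):
    """Toy pass: rewrite every call to `src_name` into a call to `dst_name`."""
    def __init__(self, src_name, dst_name):
        super(RewriteOp, self).__init__()
        self.src_name = src_name
        self.dst_op = relay.op.get(dst_name)

    def visit_call(self, call):
        new_fn = self.visit(call.op)
        new_args = [self.visit(arg) for arg in call.args]
        if isinstance(call.op, Op) and call.op.name == self.src_name:
            new_fn = self.dst_op
        return Call(new_fn, new_args, call.attrs)

x = relay.var("x", shape=(4,))
f = relay.Function([x], relay.tanh(x))
g = RewriteOp("tanh", "sigmoid").visit(f)   # tanh(x) becomes sigmoid(x)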
19 changes: 19 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -449,3 +449,22 @@ def layout_transform(data, src_layout, dst_layout):
The transformed tensor.
"""
return _make.layout_transform(data, src_layout, dst_layout)


def bitpack(data, lanes):
    """Bitpack the innermost dimension of the tensor.

    Parameters
    ----------
    data : relay.Expr
        The source tensor to be packed.

    lanes : int
        The number of lanes to pack into each element.

    Returns
    -------
    ret : relay.Expr
        The packed tensor, with the innermost dimension divided by `lanes`.
    """
    return _make.bitpack(data, lanes)
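A hypothetical usage sketch, not part of this diff, assuming relay.bitpack is re-exported at the package level and relay.ir_pass.infer_type is available. Packing an (8, 16) int8 tensor with lanes=8 yields an (8, 2) tensor of the same dtype:

from tvm import relay

x = relay.var("x", shape=(8, 16), dtype="int8")
y = relay.ir_pass.infer_type(relay.bitpack(x, lanes=8))
print(y.checked_type)   # expected: Tensor[(8, 2), int8]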

43 changes: 43 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -1698,5 +1698,48 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
.set_support_level(5)
.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);

bool BitPackRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  const auto* data = types[0].as<TensorTypeNode>();
  const BitPackAttrs* bpattrs = attrs.as<BitPackAttrs>();
  CHECK(data);
  CHECK(bpattrs);
  CHECK_NE(data->shape.size(), 0);
  Array<IndexExpr> dshape = data->shape;
  auto last_dim = topi::GetConstInt(dshape[dshape.size() - 1]);
  CHECK_EQ(last_dim % bpattrs->lanes, 0)
      << "innermost dimension must be divisible by lanes";
  auto packed_dim = tvm::Integer(last_dim / bpattrs->lanes);
  dshape.Set(dshape.size() - 1, packed_dim);
  reporter->Assign(types[num_inputs], TensorTypeNode::make(dshape, data->dtype));
  return true;
}

Expr MakeBitPack(Expr data, int lanes) {
  auto attrs = make_node<BitPackAttrs>();
  attrs->lanes = lanes;
  static const Op& op = Op::Get("bitpack");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.bitpack")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeBitPack, args, rv);
  });

RELAY_REGISTER_OP("bitpack")
.describe(R"code(Applies bit packing to the innermost dimension.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BitPackAttrs")
.set_num_inputs(1)
.add_argument("data", "nD Tensor", "The input tensor.")
.add_type_rel("BitPack", BitPackRel)
.set_support_level(5);

} // namespace relay
} // namespace tvm
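To make the shape rule in BitPackRel concrete: every dimension is preserved except the innermost, which must be a compile-time constant divisible by lanes and is divided by it. A small Python mirror of that arithmetic, illustrative only:

def bitpack_out_shape(in_shape, lanes):
    """Mirror of BitPackRel's shape arithmetic (illustrative only)."""
    assert len(in_shape) > 0, "input must have at least one dimension"
    last = in_shape[-1]
    assert last % lanes == 0, "innermost dimension must be divisible by lanes"
    return tuple(in_shape[:-1]) + (last // lanes,)

assert bitpack_out_shape((1, 64, 56, 56), lanes=8) == (1, 64, 56, 7)
assert bitpack_out_shape((8, 16), lanes=8) == (8, 2)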
1 change: 1 addition & 0 deletions vta/python/vta/top/__init__.py
@@ -4,3 +4,4 @@
from . import op
from . import relay_op
from . import bitpack
from . import relay_bitpack
