diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 9e88c774728be..6621eaeeebf43 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -188,6 +188,15 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   }
 };
 
+struct BitPackAttrs : public tvm::AttrsNode<BitPackAttrs> {
+  int lanes;
+
+  TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") {
+    TVM_ATTR_FIELD(lanes)
+        .describe("Number of lanes packed in one element");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 9de0344bf6b91..318682750a9d8 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -344,6 +344,143 @@ def realize(self):
         return _expr.TempExprRealize(self)
 
 
+class ExprFunctor(object):
+    """
+    An abstract visitor defined over Expr.
+
+    Defines the default dispatch over expressions and
+    implements memoization.
+    """
+    def __init__(self):
+        self.memo_map = {}
+
+    # pylint: disable=no-else-return
+    def visit(self, expr):
+        """Apply the visitor to an expression."""
+        from .op.op import Op
+        if expr in self.memo_map:
+            return self.memo_map[expr]
+
+        if isinstance(expr, Function):
+            res = self.visit_function(expr)
+        elif isinstance(expr, Call):
+            res = self.visit_call(expr)
+        elif isinstance(expr, Let):
+            res = self.visit_let(expr)
+        elif isinstance(expr, Var):
+            res = self.visit_var(expr)
+        elif isinstance(expr, GlobalVar):
+            res = self.visit_global_var(expr)
+        elif isinstance(expr, If):
+            res = self.visit_if(expr)
+        elif isinstance(expr, Tuple):
+            res = self.visit_tuple(expr)
+        elif isinstance(expr, TupleGetItem):
+            res = self.visit_tuple_getitem(expr)
+        elif isinstance(expr, Constant):
+            res = self.visit_constant(expr)
+        elif isinstance(expr, Op):
+            res = self.visit_op(expr)
+        else:
+            raise Exception("unhandled case: {0}".format(type(expr)))
+
+        self.memo_map[expr] = res
+        return res
+
+    def visit_function(self, _):
+        raise NotImplementedError()
+
+    def visit_let(self, _):
+        raise NotImplementedError()
+
+    def visit_call(self, _):
+        raise NotImplementedError()
+
+    def visit_var(self, _):
+        raise NotImplementedError()
+
+    def visit_type(self, typ):
+        return typ
+
+    def visit_if(self, _):
+        raise NotImplementedError()
+
+    def visit_tuple(self, _):
+        raise NotImplementedError()
+
+    def visit_tuple_getitem(self, _):
+        raise NotImplementedError()
+
+    def visit_constant(self, _):
+        raise NotImplementedError()
+
+    def visit_global_var(self, _):
+        raise NotImplementedError()
+
+    def visit_op(self, _):
+        raise NotImplementedError()
+
+
+class ExprMutator(ExprFunctor):
+    """
+    A functional visitor over Expr.
+
+    The default behavior recursively traverses the AST
+    and reconstructs it.
+ """ + def visit_function(self, fn): + new_body = self.visit(fn.body) + return Function( + list(fn.params), + new_body, + fn.ret_type, + fn.type_params) + + def visit_let(self, let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_var(self, rvar): + return rvar + + def visit_global_id(self, global_var): + return global_var + + def visit_if(self, ite): + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_tuple_getitem(self, op): + tuple_value = self.visit(op.tuple_value) + if not tuple_value.same_as(op.tuple_value): + return TupleGetItem(tuple_value, op.index) + return op + + def visit_global_var(self, gvar): + return gvar + + def visit_constant(self, rconst): + return rconst + + def visit_op(self, op): + return op + + class TupleWrapper(object): """TupleWrapper. diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index bc0a42d6ab309..e34b8f6c275f8 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -449,3 +449,22 @@ def layout_transform(data, src_layout, dst_layout): The transformed tensor. """ return _make.layout_transform(data, src_layout, dst_layout) + +def bitpack(data, lanes): + """Bitpack the innermost dimension of the tensor. + + Parameters + ---------- + data : relay.Expr + The source tensor to be packed. + + lanes : int + The lanes to pack by. + + Returns + ------- + ret : relay.Expr + The transformed tensor. + """ + return _make.bitpack(data, lanes) + diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 83d6c465ff74d..76aaab77d6055 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1698,5 +1698,48 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); +bool BitPackRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + const auto* data = types[0].as(); + const auto* out = types[num_inputs].as(); + const BitPackAttrs* bpattrs = attrs.as(); + CHECK(data); + CHECK(out); + CHECK(bpattrs); + CHECK_EQ(out->shape.size(), 1U); + CHECK(data->shape.size() != 0); + Array dshape = data->shape; + auto last_dim = topi::GetConstInt(dshape[dshape.size() - 1]); + CHECK_EQ(last_dim % bpattrs->lanes, 0); + auto packed_dim = tvm::Integer(last_dim / bpattrs->lanes); + dshape.Set(dshape.size() - 1, packed_dim); + reporter->Assign(types[num_inputs], TensorTypeNode::make(dshape, data->dtype)); + return true; +} + +Expr MakeBitPack(Expr data, int lanes) { + auto attrs = make_node(); + attrs->lanes = std::move(lanes); + static const Op& op = Op::Get("bitpack"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.bitpack") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeBitPack, args, rv); +}); + +RELAY_REGISTER_OP("bitpack") +.describe(R"code(Applies bit packing to innermost dimension. 
+ +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.BitPackAttrs") +.set_num_inputs(1) +.add_argument("data", "nD Tensor", "The input tensor.") +.add_type_rel("BitPack", BitPackRel) +.set_support_level(5); + } // namespace relay } // namespace tvm diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py index db0ddb10939a5..da2fe4af681bd 100644 --- a/vta/python/vta/top/__init__.py +++ b/vta/python/vta/top/__init__.py @@ -4,3 +4,4 @@ from . import op from . import relay_op from . import bitpack +from . import relay_bitpack