[VTA][OpenCL] add device_annot support in graphpack (apache#6125)
* add device_annot support in graphpack

* on_device annotation

* lint

* typo

* fix lint

* fix lint
zhanghaohit authored and trevor-m committed Jan 21, 2021
1 parent a9dd53d commit 86d32e5
Showing 3 changed files with 140 additions and 3 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -89,6 +89,10 @@
 register_broadcast_schedule("fast_exp")
 register_broadcast_schedule("fast_tanh")
 register_broadcast_schedule("fast_erf")
+# a fake on_device schedule.
+# it will not be used in actual computation,
+# as on_device is removed during the device annotation pass
+register_injective_schedule("on_device")
 
 
 # zeros
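For context, a minimal sketch (not part of this commit; the variable names are illustrative) of how `on_device` ends up in a graph. Any op still present when schedules are looked up needs a registered schedule, which is why the fake injective schedule above is required even though the op is a pure placement marker:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="int8")
y = relay.add(x, relay.const(1, "int8"))
# mark `y` to run on the external (VTA) device; the op itself computes nothing
y = relay.annotation.on_device(y, tvm.context("ext_dev"))
func = relay.Function([x], y)
print(func)  # on_device stays in the IR until the device annotation pass
```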
7 changes: 6 additions & 1 deletion src/relay/op/annotation/annotation.cc
@@ -54,7 +54,12 @@ RELAY_REGISTER_OP("on_device")
     .add_type_rel("Identity", IdentityRel)
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_attr<TOpIsStateful>("TOpIsStateful", false)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                              const Type& out_type) -> Array<te::Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
 
 Expr StopFusion(Expr data) {
   static const Op& op = Op::Get("annotation.stop_fusion");
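The `FTVMCompute` registered above lowers `on_device` to `topi::identity`, a plain copy of its input. A hedged Python illustration of the same identity compute (an assumption-level sketch, not code from this commit):

```python
import tvm
from tvm import te, topi

# identity compute, mirroring the C++ lambda registered above
x = te.placeholder((4,), name="x", dtype="int8")
y = topi.identity(x)  # out[i] = x[i]

s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], target="llvm")  # compiles to a plain copy
```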
132 changes: 130 additions & 2 deletions vta/python/vta/top/graphpack.py
@@ -185,6 +185,100 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
     return operator_current_idx


+class ExprDeviceAnnot(ExprMutator):
+    """Visitor that annotates the calls of an AST with the device they run on.
+
+    Parameters
+    ----------
+    start: int
+        the start location (inclusive) of the region marked to run on VTA
+    end: int
+        the end location (exclusive) of the region marked to run on VTA
+    """
+
+    def __init__(self, start=-1, end=-1):
+        self.ext_ctx = tvm.context("ext_dev")
+        self.cpu_ctx = tvm.context("cpu")
+        self.cast = op.op.get("cast")
+        self.counter = -1
+        self.start = start
+        self.end = end
+        super().__init__()
+
+    def visit_call(self, call):
+        """Visit the children first, then annotate this call."""
+        args = [self.visit(arg) for arg in call.args]
+
+        self.counter += 1
+        if self.counter == self.start:
+            ret = relay.Call(call.op, args, call.attrs)
+            ret = relay.annotation.on_device(ret, self.ext_ctx)
+            return ret
+
+        if self.counter == self.end:
+            ret = relay.Call(call.op, args, call.attrs)
+            ret = relay.annotation.on_device(ret, self.cpu_ctx)
+            return ret
+
+        if self.counter > self.start and self.counter < self.end:
+            ret = relay.Call(call.op, args, call.attrs)
+
+            # skip the float ops, i.e., the float->int cast tail
+            if self.is_float_op(call):
+                return ret
+
+            return relay.annotation.on_device(ret, self.ext_ctx)
+
+        return relay.Call(self.visit(call.op), args, call.attrs)
+
+    def is_float_op(self, call):
+        """Check whether this call is a float op.
+
+        In general, a float op's output dtype is float; the special case is
+        the float->int cast, which follows this op sequence:
+        multiply(float) -> round(float) -> clip(float) -> cast(int)
+        """
+        args = call.args
+        odtype = _get_tensor_type(call)
+
+        if odtype == "float32":
+            return True
+
+        if call.op == self.cast:
+            idtype = _get_tensor_type(args[0])
+            if idtype == "float32":
+                return True
+
+        return False
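To make the special case in `is_float_op` concrete, here is an illustrative (made-up) fragment of the float->int quantization tail it keeps on the CPU; every op before the `cast` has a float32 output:

```python
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.multiply(x, relay.const(8.0))       # float32
y = relay.round(y)                            # float32
y = relay.clip(y, a_min=-127.0, a_max=127.0)  # float32
y = relay.cast(y, "int8")                     # int output, float32 input
```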


+class ExprLocator(ExprMutator):
+    """Visitor that locates the ops of an AST, recording the visit index of
+    each call node keyed by (op, output dtype)."""
+
+    def __init__(self):
+        self.counter = -1
+        self.op2nodes = {}
+        super().__init__()
+
+    def visit_call(self, call):
+        """Visit the children first, then record this call."""
+        args = [self.visit(arg) for arg in call.args]
+
+        odtype = _get_tensor_type(call)
+        self.counter += 1
+        if (call.op, odtype) in self.op2nodes:
+            self.op2nodes[(call.op, odtype)].append(self.counter)
+        else:
+            self.op2nodes[(call.op, odtype)] = [self.counter]
+
+        return relay.Call(self.visit(call.op), args, call.attrs)
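A hedged usage sketch for `ExprLocator` (`func` is an assumed Relay function, not defined in this commit): after a visit, `op2nodes` maps each `(op, output dtype)` pair to the visit indices of the matching calls, which `graph_pack` below uses to choose the annotation boundaries:

```python
from tvm import relay

locator = ExprLocator()
locator.visit(func)  # `func`: an assumed, already-packed relay.Function

conv2d = relay.op.get("nn.conv2d")
# e.g. [2, 5, 9] -- visit indices of the int32-output conv2d calls
print(locator.op2nodes.get((conv2d, "int32"), []))
```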


 class ExprPack(ExprMutator):
     """Visitor to perform graph packing on an AST."""

@@ -427,6 +521,9 @@ def graph_pack(
     start_name_idx=None,
     stop_name_idx=None,
     count_meta=False,
+    device_annot=False,
+    annot_start_name="nn.conv2d",
+    annot_end_name="annotation.stop_fusion",
 ):
"""Pack the graph into batch&channel packed format.
@@ -464,16 +561,47 @@
         'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
         logic would count the meta.
+    device_annot: boolean, optional
+        whether to annotate the device type of each node
+    annot_start_name: str, optional
+        device annotation start node, from which nodes are marked to run on `ext_dev`
+    annot_end_name: str, optional
+        device annotation end node, after which nodes are marked to run on `cpu`
     Returns
     -------
     expr : Expr
         The transformed expression.
     """
     assert isinstance(expr, relay.Function)
-    assert (start_name != stop_name) or (start_name_idx < stop_name_idx)
+    assert (
+        (start_name != stop_name)
+        or (start_name_idx is None != stop_name_idx is None)
+        or (not (start_name_idx is None and stop_name_idx is None))
+        or (start_name_idx < stop_name_idx)
+    )
     expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
     expr = run_opt_pass(expr, transform.InferType())
     packer = ExprPack(bfactor, cfactor, weight_bits)
     expr = packer.visit(expr)
     assert not packer.start_pack
-    return run_opt_pass(expr, transform.InferType())
+    expr = run_opt_pass(expr, transform.InferType())
+
+    if device_annot:
+        expr_locator = ExprLocator()
+        expr_locator.visit(expr)
+
+        annot_start = op.op.get(annot_start_name)
+        start = expr_locator.op2nodes[(annot_start, "int32")][0]
+
+        annot_end = op.op.get(annot_end_name)
+        # mark the op following the last stop_fusion to run on the cpu device
+        end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1
+
+        device_annot = ExprDeviceAnnot(start=start, end=end)
+        expr = device_annot.visit(expr)
+        return run_opt_pass(expr, transform.InferType())
+
+    return expr
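Putting it together, a hedged sketch of calling `graph_pack` with device annotation enabled; the program, environment, and stop op are assumptions modeled on the VTA tutorials, not prescribed by this commit:

```python
import vta
from vta.top import graph_pack

env = vta.get_env()   # assumed VTA environment
relay_prog = ...      # assumed: a quantized relay.Function

packed = graph_pack(
    relay_prog,
    env.BATCH,          # bfactor
    env.BLOCK_OUT,      # cfactor
    env.WGT_WIDTH,      # weight_bits
    start_name="nn.conv2d",
    stop_name="nn.global_avg_pool2d",
    device_annot=True,  # annotate ops from the first int32 conv2d through
                        # the op after the last int8 stop_fusion
)
```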
