Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghaohit committed Nov 2, 2020
1 parent 2169263 commit 20f0470
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions vta/python/vta/top/graphpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class ExprDeviceAnnot(ExprMutator):
---------
None
"""

def __init__(self, start=-1, end=-1):
self.ext_ctx = tvm.context("ext_dev")
self.cpu_ctx = tvm.context("cpu")
Expand Down Expand Up @@ -256,8 +257,8 @@ def is_float_op(self, call):


class ExprLocator(ExprMutator):
"""Visitor to locate op on an AST.
"""
"""Visitor to locate op on an AST."""

def __init__(self):
self.counter = -1
self.op2nodes = {}
Expand All @@ -275,10 +276,7 @@ def visit_call(self, call):
else:
self.op2nodes[(call.op, odtype)] = [self.counter]

return relay.Call(
self.visit(call.op),
args,
call.attrs)
return relay.Call(self.visit(call.op), args, call.attrs)


class ExprPack(ExprMutator):
Expand Down Expand Up @@ -514,17 +512,17 @@ def _recursion(anf, start_found, stop_found, operator_current_idx):


def graph_pack(expr,
bfactor,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False,
device_annot=False,
annot_start_name="nn.conv2d",
annot_end_name="annotation.stop_fusion"):
bfactor,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False,
device_annot=False,
annot_start_name="nn.conv2d",
annot_end_name="annotation.stop_fusion"):
"""Pack the graph into batch&channel packed format.
Parameters
Expand Down Expand Up @@ -576,9 +574,12 @@ def graph_pack(expr,
The transformed expression.
"""
assert isinstance(expr, relay.Function)
assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \
(not (start_name_idx is None and stop_name_idx is None)) \
or (start_name_idx < stop_name_idx))
assert (
(start_name != stop_name)
or (start_name_idx is None != stop_name_idx is None)
or (not (start_name_idx is None and stop_name_idx is None))
or (start_name_idx < stop_name_idx)
)
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(bfactor, cfactor, weight_bits)
Expand Down

0 comments on commit 20f0470

Please sign in to comment.