lint
zhanghaohit committed Aug 24, 2020
1 parent d44be38 commit d25d0f3
Showing 3 changed files with 15 additions and 17 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/op/_tensor.py
@@ -88,7 +88,8 @@
register_broadcast_schedule("fast_tanh")
register_broadcast_schedule("fast_erf")
# a fake on_device schedule.
# this will not be used in actual computation as on_device will be removed during DeviceAnnotation pass
# this will not be used in actual computation
# as on_device will be removed during DeviceAnnotation pass
register_injective_schedule("on_device")


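For context, the schedule registered above is only a placeholder: on_device calls are consumed by the device-annotation passes before any compute is actually scheduled. Below is a minimal sketch, assuming only the relay.annotation.on_device and tvm.cpu APIs already used in this commit, of how such an annotation appears in a Relay program before the pass removes it.

# Minimal sketch, assuming the relay.annotation.on_device and tvm.cpu APIs
# referenced elsewhere in this commit; illustrative only.
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
y = relay.add(x, x)
# Pin the add to the CPU context. The on_device wrapper is stripped out by
# the device-annotation passes, so the fake injective schedule registered
# above is never executed.
y = relay.annotation.on_device(y, tvm.cpu(0))
func = relay.Function([x], y)
print(func)  # the printed IR still shows the on_device call at this point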
7 changes: 3 additions & 4 deletions src/relay/op/annotation/annotation.cc
@@ -56,10 +56,9 @@ RELAY_REGISTER_OP("on_device")
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[] (const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});

Expr StopFusion(Expr data) {
22 changes: 10 additions & 12 deletions vta/python/vta/top/graphpack.py
@@ -201,21 +201,20 @@ def __init__(self, start=-1, end=-1):
def visit_call(self, call):
""" Visit the children. """
# First visit the children.
oshape = _get_tensor_shape(call)
odtype = _get_tensor_type(call)
input_types = [arg.checked_type for arg in call.args]
args = [self.visit(arg) for arg in call.args]

self.counter += 1
if self.counter == self.start:
ret = relay.Call(call.op, args, call.attrs)
ret = relay.annotation.on_device(ret, self.ext_ctx)
return ret
elif self.counter == self.end:

if self.counter == self.end:
ret = relay.Call(call.op, args, call.attrs)
ret = relay.annotation.on_device(ret, self.cpu_ctx)
return ret
elif self.counter > self.start and self.counter < self.end:

if self.counter > self.start and self.counter < self.end:
ret = relay.Call(call.op, args, call.attrs)

# skip the float op, i.e., float->int cast
@@ -234,11 +233,11 @@ def is_float_op(self, call):
"""
args = call.args
odtype = _get_tensor_type(call)
op = call.op

if odtype == "float32":
return True
elif op == self.cast:

if call.op == self.cast:
idtype = _get_tensor_type(args[0])
if idtype == "float32":
return True
@@ -566,7 +565,8 @@ def graph_pack(expr,
"""
assert isinstance(expr, relay.Function)
assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \
(not (start_name_idx is None and stop_name_idx is None)) or (start_name_idx < stop_name_idx))
(not (start_name_idx is None and stop_name_idx is None)) \
or (start_name_idx < stop_name_idx))
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(
@@ -589,8 +589,6 @@

device_annot = ExprDeviceAnnot(start=start, end=end)
expr = device_annot.visit(expr)
ret = run_opt_pass(expr, transform.InferType())
return run_opt_pass(expr, transform.InferType())

return ret
else:
return expr
return expr
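As a usage note, the annotator above is driven exactly as graph_pack drives it: construct it with the call-counter window and visit the function. The snippet below is a sketch only; packed_func and the start/end values are illustrative placeholders, while ExprDeviceAnnot, run_opt_pass, and transform are names already used in this file.

# Sketch: mark the call window [start, end]; the start call is pinned to the
# accelerator context and the end call back to the CPU context, mirroring
# visit_call above. packed_func is a hypothetical relay.Function.
device_annot = ExprDeviceAnnot(start=3, end=10)
annotated = device_annot.visit(packed_func)
annotated = run_opt_pass(annotated, transform.InferType())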
