
add some comments
masahi committed Feb 28, 2020
1 parent 749522d commit 25e2895
Showing 1 changed file with 13 additions and 7 deletions: python/tvm/relay/frontend/pytorch.py
@@ -213,12 +213,10 @@ def _impl(inputs, input_types):
else:
assert "data type {} could not be parsed in conv op" % (type(weight))

# TODO: Add reshape when channel multiplier > 1. Pending PR #4644
channels = weight_shape[0]
groups = int(inputs[8])

if groups > 1:
# in torch, groups == in_channels for depthwise conv
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
weight = _op.transform.reshape(weight, new_weight_shape)
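
To make the reshape above concrete, here is a minimal standalone sketch of the shape arithmetic, using a hypothetical depthwise torch.nn.Conv2d(8, 16, 3, groups=8); torch stores conv weights as (out_channels, in_channels // groups, kH, kW):

weight_shape = (16, 1, 3, 3)  # hypothetical depthwise conv weight
groups = 8

channels = weight_shape[0]               # 16 output channels
channel_multiplier = channels // groups  # 16 // 8 == 2

# the reshape performed above; the element count is unchanged
new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
assert new_weight_shape == (8, 2, 3, 3)
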
@@ -722,7 +720,6 @@ def _convert_elemwise_input(data, input_type):
return data

def _wrap_const(c):
# TODO: replace this function with something that already exists above
if not isinstance(c, _expr.Expr) and not isinstance(c, list):
return _expr.const(c)
return c
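
For context, _wrap_const turns bare Python scalars into Relay constants while passing expressions and lists through untouched. A standalone sketch, assuming _expr resolves to tvm.relay's expression module (which exposes const and Expr):

from tvm import relay

def wrap_const(c):
    # wrap plain Python scalars into Relay Constant nodes; leave
    # existing Relay expressions and lists untouched
    if not isinstance(c, relay.Expr) and not isinstance(c, list):
        return relay.const(c)
    return c

print(wrap_const(1.5))             # becomes a Constant node
print(wrap_const(relay.const(2)))  # already an Expr, returned unchanged
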
@@ -876,7 +873,7 @@ def _get_input_types(op_node):
if in_ty.scalarType() is None:
# Tensor's type can be unknown if we use torch.jit.script(...)
# Defaults to float for now
logging.warn("Untyped Tensor found, assume it is float")
logging.warning("Untyped Tensor found, assume it is float")
input_list_types.append("float")
else:
input_list_types.append(in_ty.scalarType().lower())
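
The one-line fix above swaps a deprecated alias for the documented API: logging.warn has been deprecated in favor of logging.warning since Python 3.3 and emits a DeprecationWarning when warnings are enabled. A minimal illustration:

import logging

logging.basicConfig()
# logging.warn("...") still works but raises DeprecationWarning under -W default
logging.warning("Untyped Tensor found, assume it is float")
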
@@ -1025,13 +1022,16 @@ def parse_params(graph, state_dict):


def parse_block(block, outputs, output_index_map):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_name = _get_input_names(block.returnNode())[0]
return parse_operators(ops, outputs, output_index_map, ret_name)


def parse_loop(op_node, outputs, output_index_map):

""" Translate Torch prim::Loop to Relay while_loop """
# Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
def get_input(index):
inode = op_node.inputsAt(index).node()
if inode.kind() == "prim::Constant":
@@ -1044,15 +1044,20 @@ def get_input(index):
return out
return _expr.const(out)
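To see the structure parse_loop consumes, here is a minimal sketch (the scripted function is made up) that produces a graph containing a prim::Loop node of the shape described in the spec linked above:

import torch

@torch.jit.script
def count_up(n: int) -> int:
    total = 0
    for i in range(n):
        total = total + i
    return total

# The printed IR contains a node roughly of this form (value names vary):
#   %total : int = prim::Loop(%n, %true, %total.1)
#     block0(%i : int, %total.2 : int):
#       ...
print(count_up.graph)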

# The first input: %max_trip_count
# The second input: %initial_condition
# The rest of the inputs: loop variables
max_loop_count = get_input(0)
init_cond = get_input(1)
num_loop_var = len(list(op_node.inputs())) - 2
init_vals = [get_input(i + 2) for i in range(num_loop_var)]

# A for loop (as opposed to a while loop) always has %initial_condition equal to 1
is_for_loop = isinstance(init_cond, _expr.Constant)

if is_for_loop:
loop_iter_dtype = "int32"
# always count from 0
init_loop_iter_val = _expr.const(0, dtype="int32")
else:
loop_iter_dtype = "bool"
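
The is_for_loop test above works because TorchScript emits a for loop's %initial_condition as a prim::Constant holding true, which get_input converts to a Relay Constant; a computed while-loop condition stays a non-constant expression. A tiny sketch of the distinction:

from tvm import relay

init_cond = relay.const(True, "bool")  # what a for loop yields
assert isinstance(init_cond, relay.Constant)

cond_var = relay.var("c", shape=(), dtype="bool")  # a computed condition
assert not isinstance(cond_var, relay.Constant)
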
@@ -1077,6 +1082,7 @@ def cond(*current_vals):
return _op.equal(i, _expr.const(True, 'bool'))

def body(*current_vals):
# Update loop variables
for (i, iname) in enumerate(inames):
outputs[output_index_map[iname]] = current_vals[i]

@@ -1106,6 +1112,7 @@ def get_var(name, val):
loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
loop_val = loop(init_loop_iter_val, *init_vals)

# The first element is a loop counter or boolean condition, ignore it
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
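
The construction above follows the generic pattern of tvm.relay.loops.while_loop: the first loop variable is the iteration counter (or carried condition), and calling the loop returns a tuple whose element 0 is that counter, hence the i+1 offset when extracting results. A self-contained sketch (toy variables, not from the diff) that sums 0..9:

from tvm import relay
from tvm.relay.loops import while_loop

i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

def cond(i, acc):
    return relay.less(i, relay.const(10, "int32"))

def body(i, acc):
    # advance the counter, update the carried value
    return [i + relay.const(1, "int32"), acc + i]

loop = while_loop(cond, [i, acc], body)
loop_val = loop(relay.const(0, "int32"), relay.const(0, "int32"))
# drop the counter at index 0, keep the accumulated sum
result = relay.TupleGetItem(loop_val, 1)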


@@ -1153,8 +1160,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):

def get_all_op_names(graph):
""" Return all operator names in the input graph """
nodes = list(graph.nodes())
return set(node.kind() for node in nodes)
return set(node.kind() for node in graph.nodes())
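
As a quick usage check for the simplified helper above (the scripted function is hypothetical):

import torch

def get_all_op_names(graph):
    """ Return all operator names in the input graph """
    return set(node.kind() for node in graph.nodes())

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# prints something like {'prim::Constant', 'aten::add'}
print(get_all_op_names(add_one.graph))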


def get_graph_input_names(script_module):
