Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Bug fixes in control flow operators (#11942)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and szha committed Aug 3, 2018
1 parent ae698f9 commit 22c97ef
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,12 @@ def _union_inputs(*graphs):
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some loop_vars are inputs to `graph`, some are not
name_to_loop_vars = {sym.name: sym for sym in loop_vars}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# also we collect the mapping from var's name to var's loc in loop_vars
name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
# collect arguments for each subgraph
Expand Down Expand Up @@ -644,12 +644,12 @@ def _union_inputs(*graphs):
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some input_vars are inputs to `graph`, some are not
name_to_input_vars = {sym.name: sym for sym in inputs}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# collect arguments for each subgraph
input_locs = [] # results from the second step
for name in graph.list_inputs():
Expand Down Expand Up @@ -696,5 +696,4 @@ def _union_inputs(*graphs):
else_input_locs=else_input_locs,
num_outputs=then_num_outputs
)
result = _to_symbol_tuple(result, "result")
return list(result)
return [result[i] for i in range(then_num_outputs)]

0 comments on commit 22c97ef

Please sign in to comment.