Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-749] Bug fixes in control flow operators #11942

Merged
merged 1 commit into from
Aug 3, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,12 @@ def _union_inputs(*graphs):
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some loop_vars are inputs to `graph`, some are not
name_to_loop_vars = {sym.name: sym for sym in loop_vars}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# also we collect the mapping from var's name to var's loc in loop_vars
name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
# collect arguments for each subgraph
Expand Down Expand Up @@ -644,12 +644,12 @@ def _union_inputs(*graphs):
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some input_vars are inputs to `graph`, some are not
name_to_input_vars = {sym.name: sym for sym in inputs}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# collect arguments for each subgraph
input_locs = [] # results from the second step
for name in graph.list_inputs():
Expand Down Expand Up @@ -696,5 +696,4 @@ def _union_inputs(*graphs):
else_input_locs=else_input_locs,
num_outputs=then_num_outputs
)
result = _to_symbol_tuple(result, "result")
return list(result)
return [result[i] for i in range(then_num_outputs)]