diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 884288364b3d..1d42cf7c18f8 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -486,12 +486,12 @@ def _union_inputs(*graphs): input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it # to a `loc`, where inputs[loc] = sym for graph in graphs: - # input_syms: all inputs to the `graph` - name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} # some loop_vars are inputs to `graph`, some are not name_to_loop_vars = {sym.name: sym for sym in loop_vars} # other inputs to `graph` created by cut_graph name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} # also we collect the mapping from var's name to var's loc in loop_vars name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)} # collect arguments for each subgraph @@ -644,12 +644,12 @@ def _union_inputs(*graphs): input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it # to a `loc`, where inputs[loc] = sym for graph in graphs: - # input_syms: all inputs to the `graph` - name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} # some input_vars are inputs to `graph`, some are not name_to_input_vars = {sym.name: sym for sym in inputs} # other inputs to `graph` created by cut_graph name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} # collect arguments for each subgraph input_locs = [] # results from the second step for name in graph.list_inputs(): @@ -696,5 +696,4 @@ def _union_inputs(*graphs): else_input_locs=else_input_locs, num_outputs=then_num_outputs ) - result = _to_symbol_tuple(result, "result") - return list(result) + return [result[i] for i in range(then_num_outputs)]