diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 78b5a7262121..80627b30fcbf 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1006,29 +1006,16 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
         if self._backend:
             ctx = args[0].context
@@ -1037,10 +1024,25 @@ def _build_cache(self, *args):
                          for name in out.list_arguments()]
             aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
                          for name in out.list_auxiliary_states()]
+            # Partition the graph.
             out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
             #update cached graph with partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            if name in data_names:
+                data_indices.append(i)
+                self._cached_op_args.append((True, data_names[name]))
+            else:
+                param_indices.append(i)
+                self._cached_op_args.append((False, params[name]))
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index d4ff7954c181..21e5c8190fb1 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1554,7 +1554,37 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                     aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
                 else:
                     aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
-        return Symbol(out)
+
+        new_sym = Symbol(out)
+
+        if args is not None:
+            new_arg_names = new_sym.list_arguments()
+            deleted_arg_names = set([item for item in arg_names
+                                     if item not in set(new_arg_names)])
+            if isinstance(args, dict):
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            elif isinstance(args, list):
+                indices_to_delete = [i for i, name in enumerate(arg_names) if name in deleted_arg_names]
+                for idx in reversed(indices_to_delete):
+                    args.pop(idx)
+
+        if aux is not None:
+            new_aux_names = new_sym.list_auxiliary_states()
+            deleted_aux_names = set([item for item in aux_names
+                                     if item not in set(new_aux_names)])
+
+            if isinstance(aux, dict):
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            elif isinstance(aux, list):
+                indices_to_delete = [i for i, name in enumerate(aux_names) if name in deleted_aux_names]
+                for idx in reversed(indices_to_delete):
+                    aux.pop(idx)
+
+        return new_sym
 
 # pylint: disable=too-many-locals
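
The symbol.py change above mutates the caller's args/aux in place so they stay consistent with the partitioned graph, which may list fewer inputs than the original symbol. Below is a minimal, self-contained sketch of that pruning step, reduced to plain dicts and lists so it runs without MXNet; prune_inputs and all names in it are illustrative assumptions, not MXNet API.

# Standalone sketch of the pruning logic added to Symbol.optimize_for.
# Assumption: partitioning can only drop (never reorder) existing inputs.
def prune_inputs(container, old_names, new_names):
    """Drop entries whose names the partitioned graph no longer lists.

    container is either a dict keyed by input name, or a list ordered
    like old_names (mirroring the dict/list branches in the diff).
    """
    deleted = set(old_names) - set(new_names)
    if isinstance(container, dict):
        for name in deleted:
            container.pop(name, None)
    elif isinstance(container, list):
        # Pop the highest indices first so the remaining indices stay
        # valid, matching the reversed() iteration in the diff.
        for idx in reversed([i for i, n in enumerate(old_names) if n in deleted]):
            container.pop(idx)
    return container

# Example: a backend folds 'bn_gamma' into a fused op during partitioning.
old_names = ['conv_weight', 'conv_bias', 'bn_gamma']
new_names = ['conv_weight', 'conv_bias']
print(prune_inputs({'conv_weight': 0, 'conv_bias': 1, 'bn_gamma': 2},
                   old_names, new_names))                   # {'conv_weight': 0, 'conv_bias': 1}
print(prune_inputs(['w', 'b', 'g'], old_names, new_names))  # ['w', 'b']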