diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 41358e060e94..d4f78c1017e0 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1101,9 +1101,11 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, const int* provided_arg_dtypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, - mx_uint* shared_buffer_len, - const char*** shared_buffer_name_list, - NDArrayHandle** shared_buffer_handle_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, mx_uint* num_in_args, NDArrayHandle** in_args, NDArrayHandle** arg_grads, diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 37fc097faf94..6b30784fb341 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1300,7 +1300,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, # prepare shared_buffer if shared_buffer is None: - shared_buffer_len = mx_uint() + shared_buffer_len = ctypes.c_int(-1) shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() else: @@ -1312,8 +1312,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer_names.append(c_str(k)) shared_buffer_handles.append(v.handle) shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names) - shared_buffer_len = mx_uint(len(shared_buffer_handles)) + shared_buffer_len = ctypes.c_int(len(shared_buffer_handles)) shared_buffer_handles = c_array(NDArrayHandle, shared_buffer_handles) + updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() + updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() # prepare shared_exec_handle shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle() @@ -1348,8 +1350,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, mx_uint(len(shared_arg_name_list)), c_array(ctypes.c_char_p, shared_arg_name_list), ctypes.byref(shared_buffer_len), - ctypes.byref(shared_buffer_names), - ctypes.byref(shared_buffer_handles), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), ctypes.byref(num_in_args), ctypes.byref(in_arg_handles), ctypes.byref(arg_grad_handles), @@ -1360,11 +1364,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, # update shared_buffer if shared_buffer is not None: - updated_shared_buffer = [NDArray(NDArrayHandle(shared_buffer_handles[i])) - for i in range(shared_buffer_len.value)] - updated_shared_buffer_names = [py_str(shared_buffer_names[i]) - for i in range(shared_buffer_len.value)] - for k, v in zip(updated_shared_buffer_names, updated_shared_buffer): + for i in range(shared_buffer_len.value): + k = py_str(updated_shared_buffer_names[i]) + v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i])) shared_buffer[k] = v # create in_args, arg_grads, and aux_states for the current executor diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index cbc6fb21bb2f..8d40514bae49 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -178,6 +178,8 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding * \param num_in_args number of input arguments of this sym * \param in_args list_arguments associated with the current executor * \param arg_grads list of gradients of in_args associated with the current executor @@ -205,9 +207,11 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, const int* provided_arg_dtypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, - mx_uint* shared_buffer_len, - const char*** shared_buffer_name_list, - NDArrayHandle** shared_buffer_handle_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, mx_uint* num_in_args, NDArrayHandle** in_args, NDArrayHandle** arg_grads, @@ -373,14 +377,14 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector shared_exec_in_args; std::vector shared_exec_arg_grads; std::vector shared_exec_aux_states; - bool use_shared_buffer = (nullptr != *shared_buffer_handle_list); - if (use_shared_buffer) { + bool use_shared_buffer = (*shared_buffer_len >= 0); + if (*shared_buffer_len > 0) { // create shared_buffer_map shared_buffer_map.reserve(*shared_buffer_len); - NDArray*** shared_buffer_ptrs = - reinterpret_cast(shared_buffer_handle_list); - for (mx_uint i = 0; i < *shared_buffer_len; ++i) { - shared_buffer_map[*shared_buffer_name_list[i]] = *(*shared_buffer_ptrs)[i]; + NDArray** shared_buffer_ptrs = + reinterpret_cast(shared_buffer_handle_list); + for (int i = 0; i < *shared_buffer_len; ++i) { + shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]); } } @@ -449,8 +453,8 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, ret->ret_vec_charp.push_back(kv.first.c_str()); } *shared_buffer_len = shared_buffer_map.size(); - *shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); - *shared_buffer_name_list = &(ret->ret_vec_charp[0]); + *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); + *updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]); } API_END(); diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 24dcb7c1fcad..c8de975ed295 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -539,10 +539,6 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, data_entry_.resize(idx.num_node_entries()); size_t arg_top = 0, aux_top = 0; auto mutable_nodes = idx.mutable_input_nodes(); - const auto& shared_exec_in_args = shared_exec->in_arg_map(); - const auto& shared_exec_arg_grads = shared_exec->arg_grad_map(); - const auto& shared_exec_aux_states = shared_exec->aux_state_map(); - // TODO(junwu): populate in_arg_map, arg_grad_map, and aux_state_map for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const uint32_t eid = idx.entry_id(nid, 0); @@ -551,7 +547,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const std::string& arg_name = idx[nid].source->attrs.name; if (mutable_nodes.count(nid)) { // aux_states if (nullptr != shared_exec) { - const NDArray& aux_nd = shared_exec_aux_states.at(arg_name); + const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); CHECK_EQ(inferred_shape, aux_nd.shape()) << "Inferred shape does not match shared_exec.aux_array's shape." " Therefore, the allocated memory for shared_exec.aux_array cannot" @@ -574,7 +570,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } else { // in_args if (shared_arg_names.count(arg_name)) { // model parameter if (nullptr != shared_exec) { - const NDArray& in_arg_nd = shared_exec_in_args.at(arg_name); + const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); CHECK_EQ(inferred_shape, in_arg_nd.shape()) << "Inferred shape does not match shared_exec.arg_array's shape" " Therefore, the allocated memory for shared_exec.arg_array cannot" @@ -589,7 +585,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(shared_exec_arg_grads.at(arg_name)); + arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); } // if (kNullOp == grad_req_types[arg_top]) } else { // !has shared_exec