From fa28df042ec11431b3a4fa6d093eb2a7c5c40ac8 Mon Sep 17 00:00:00 2001
From: reminisce
Date: Sun, 7 May 2017 23:39:35 -0700
Subject: [PATCH] Move more front-end work to backend

---
 include/mxnet/c_api.h          |  20 +---
 include/mxnet/executor.h       |  21 +++-
 python/mxnet/symbol.py         | 195 ++++++++++++++++---------------
 src/c_api/c_api_executor.cc    | 202 +++++++++++++++++++++------------
 src/executor/graph_executor.cc |  69 ++++++-----
 src/executor/graph_executor.h  |  17 +--
 6 files changed, 300 insertions(+), 224 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 9466df63aade..e45994c07e08 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1086,14 +1086,10 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                                    const char** g2c_keys,
                                    const int* g2c_dev_types,
                                    const int* g2c_dev_ids,
-                                   const mx_uint in_arg_len,
-                                   const int* in_arg_dev_types,
-                                   const int* in_arg_dev_ids,
-                                   const mx_uint* grad_req_types,
-                                   const mx_uint aux_state_len,
-                                   const int* aux_state_dev_types,
-                                   const int* aux_state_dev_ids,
-                                   const mx_uint num_provided_args,
+                                   const mx_uint provided_grad_req_list_len,
+                                   const char** provided_grad_req_names,
+                                   const char** provided_grad_req_types,
+                                   const mx_uint num_provided_arg_shapes,
                                    const char** provided_arg_shape_names,
                                    const mx_uint* provided_arg_shape_data,
                                    const mx_uint* provided_arg_shape_idx,
@@ -1105,14 +1101,10 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                                    mx_uint* num_shared_data_arrays,
                                    const char*** shared_data_array_name_list,
                                    NDArrayHandle** shared_data_array_handle_list,
-                                   const mx_uint num_shared_exec_in_args,
-                                   NDArrayHandle* shared_exec_in_arg_handles,
-                                   const mx_uint num_shared_exec_arg_grads,
-                                   NDArrayHandle* shared_exec_arg_grad_handles,
-                                   const mx_uint num_shared_exec_aux_states,
-                                   NDArrayHandle* shared_exec_aux_state_handles,
+                                   mx_uint* num_in_args,
                                    NDArrayHandle** in_args,
                                    NDArrayHandle** arg_grads,
+                                   mx_uint* num_aux_states,
                                    NDArrayHandle** aux_states,
                                    ExecutorHandle shared_exec_handle,
                                    ExecutorHandle* out);
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 1603b8de5f53..2a305cfb7782 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -69,6 +69,21 @@ class Executor {
    * \return array of outputs in the executor.
    */
   virtual const std::vector<NDArray> &outputs() const = 0;
+  /*!
+   * \brief get the input argument map, key is arg name, value is arg's NDArray.
+   * \return input argument map in the executor.
+   */
+  virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
+  /*!
+   * \brief get the input argument gradient map, key is arg name, value is gradient's NDArray.
+   * \return input argument gradient map in the executor.
+   */
+  virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
+  /*!
+   * \brief get the aux state map, key is aux state name, value is aux state's NDArray.
+   * \return aux state map in the executor.
+   */
+  virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
   /*!
    * \brief Create an operator by binding the symbol with a context and arguments.
    * If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.
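The name-keyed maps added above are what let the backend wire a new executor to a shared one: instead of marshalling the parent's in_args/arg_grads/aux_states positionally through the C API (the parameters removed from MXExecutorSimpleBind above), the backend now looks each array up by argument name. A minimal sketch of the front-end usage this enables; the network, shapes, and parameter names below are illustrative, not taken from the patch:

    import mxnet as mx

    data = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(data=data, name='fc', num_hidden=64)

    # Parent executor owns the parameter memory.
    exe1 = net.simple_bind(ctx=mx.cpu(), data=(32, 100))

    # A second executor (e.g. another bucket/batch size) reuses fc_weight and
    # fc_bias; only shared_exec.handle crosses the C API, and the backend
    # resolves the arrays via in_arg_map()/arg_grad_map()/aux_state_map().
    exe2 = net.simple_bind(ctx=mx.cpu(), data=(64, 100),
                           shared_exec=exe1, param_names=['fc_weight', 'fc_bias'])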
@@ -102,9 +117,9 @@ class Executor {
                               const std::unordered_map<std::string, int>& arg_dtype_map,
                               const std::vector<OpReqType>& grad_req_types,
                               const std::unordered_set<std::string>& param_names,
-                              const std::vector<NDArray>& shared_exec_in_args,
-                              const std::vector<NDArray>& shared_exec_arg_grads,
-                              const std::vector<NDArray>& shared_exec_aux_states,
+                              // const std::vector<NDArray>& shared_exec_in_args,
+                              // const std::vector<NDArray>& shared_exec_arg_grads,
+                              // const std::vector<NDArray>& shared_exec_aux_states,
                               std::vector<NDArray>* in_args,
                               std::vector<NDArray>* arg_grads,
                               std::vector<NDArray>* aux_states,
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index baf5857a8099..26083dfdda71 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -1166,78 +1166,99 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
         executor : mxnet.Executor
             The generated executor
         """
-        listed_arguments = self.list_arguments()  # read-only args
-        listed_aux_states = self.list_auxiliary_states()  # aux states
-
-        attr_dict = None
-        if type_dict is None:
-            attr_dict = self.attr_dict()
-            type_dict = {k: mx_real_t for k in listed_arguments
-                         if k not in attr_dict or '__dtype__' not in attr_dict[k]}
-
-        provided_arg_type_names = []  # provided type argument names
-        provided_arg_type_data = []  # provided types
-        for k, v in type_dict.items():
-            v = _numpy.dtype(v).type
-            if v in _DTYPE_NP_TO_MX:
-                provided_arg_type_names.append(k)
-                provided_arg_type_data.append(_DTYPE_NP_TO_MX[v])
+        # listed_arguments = self.list_arguments()  # read-only args
+        # listed_aux_states = self.list_auxiliary_states()  # aux states
+
+        # attr_dict = None
+        # if type_dict is None:
+        #     attr_dict = self.attr_dict()
+        #     type_dict = {k: mx_real_t for k in listed_arguments
+        #                  if k not in attr_dict or '__dtype__' not in attr_dict[k]}
+
+        num_provided_arg_types = 0
+        provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)()  # provided type argument names
+        provided_arg_type_data = ctypes.POINTER(mx_uint)()  # provided types
+        if type_dict is not None:
+            provided_arg_type_names = []
+            provided_arg_type_data = []
+            for k, v in type_dict.items():
+                v = _numpy.dtype(v).type
+                if v in _DTYPE_NP_TO_MX:
+                    provided_arg_type_names.append(c_str(k))
+                    provided_arg_type_data.append(ctypes.c_int(_DTYPE_NP_TO_MX[v]))
+            num_provided_arg_types = mx_uint(len(provided_arg_type_names))
+            provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names)
+            provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data)

         provided_arg_shape_data = []  # shape data
         # argument shape index in sdata, e.g.
         # [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
         provided_arg_shape_idx = [0]
         provided_arg_shape_names = []  # provided argument names
         for k, v in kwargs.items():
-            if k not in listed_arguments and k not in listed_aux_states:
-                raise ValueError('arg name %s is not valid', k)
+            # if k not in listed_arguments and k not in listed_aux_states:
+            #     raise ValueError('arg name %s is not valid', k)
             if isinstance(v, tuple):
                 provided_arg_shape_names.append(c_str(k))
                 provided_arg_shape_data.extend(v)
                 provided_arg_shape_idx.append(len(provided_arg_shape_data))

-        req_map = {'null': 0, 'write': 1, 'add': 3}
-        if isinstance(grad_req, string_types):
-            if grad_req not in req_map:
-                raise ValueError('grad_req=%s is not in %s' % grad_req, str(req_map))
-            grad_req_types = [mx_uint(req_map[grad_req])] * len(listed_arguments)
-        elif isinstance(grad_req, list):
-            grad_req_types = [mx_uint(req_map[item]) for item in grad_req]
-        elif isinstance(grad_req, dict):
-            grad_req_types = []
-            for name in listed_arguments:
-                if name in grad_req:
-                    grad_req_types.append(mx_uint(req_map[grad_req[name]]))
-                else:
-                    grad_req_types.append(mx_uint(0))
-
-        if group2ctx is not None:
-            if attr_dict is None:
-                attr_dict = self.attr_dict()
-            arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
-                       if name in attr_dict and '__ctx_group__' in attr_dict[name]
-                       else ctx for name in listed_arguments]
-            aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
-                       if name in attr_dict and '__ctx_group__' in attr_dict[name]
-                       else ctx for name in listed_aux_states]
-        else:
-            arg_ctx = [ctx] * len(listed_arguments)
-            aux_ctx = [ctx] * len(listed_aux_states)
-
-        ctx_map_keys = []
-        ctx_map_dev_types = []
-        ctx_map_dev_ids = []
+        provided_req_type_list_len = 0
+        provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)()
+        provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)()
+        if grad_req is not None:
+            if isinstance(grad_req, string_types):
+                # use provided_req_type_list_len = 0 to indicate this situation
+                provided_req_type_list_len = 0
+                provided_grad_req_types = [c_str(grad_req)]
+            elif isinstance(grad_req, list):
+                provided_grad_req_types = [c_str(item) for item in grad_req]
+                provided_req_type_list_len = len(provided_grad_req_types)
+            elif isinstance(grad_req, dict):
+                provided_grad_req_names = []
+                provided_grad_req_types = []
+                for k, v in grad_req.items():
+                    provided_grad_req_names.append(c_str(k))
+                    provided_grad_req_types.append(c_str(v))
+                provided_grad_req_names = c_array(ctypes.c_char_p, provided_grad_req_names)
+                provided_req_type_list_len = len(provided_grad_req_types)
+            provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types)
+
+        # if group2ctx is not None:
+        #     if attr_dict is None:
+        #         attr_dict = self.attr_dict()
+        #     arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
+        #                if name in attr_dict and '__ctx_group__' in attr_dict[name]
+        #                else ctx for name in listed_arguments]
+        #     aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
+        #                if name in attr_dict and '__ctx_group__' in attr_dict[name]
+        #                else ctx for name in listed_aux_states]
+        # else:
+        #     arg_ctx = [ctx] * len(listed_arguments)
+        #     aux_ctx = [ctx] * len(listed_aux_states)
+
+        num_ctx_map_keys = mx_uint(0)
+        ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)()
+        ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)()
+        ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)()
         if group2ctx is not None:
+            ctx_map_keys = []
+            ctx_map_dev_types = []
+            ctx_map_dev_ids = []
             for key, val in group2ctx.items():
                 ctx_map_keys.append(c_str(key))
                 ctx_map_dev_types.append(ctypes.c_int(val.device_typeid))
                 ctx_map_dev_ids.append(ctypes.c_int(val.device_id))
+            num_ctx_map_keys = mx_uint(len(ctx_map_keys))
+            ctx_map_keys = c_array(ctypes.c_char_p, ctx_map_keys)
+            ctx_map_dev_types = c_array(ctypes.c_int, ctx_map_dev_types)
+            ctx_map_dev_ids = c_array(ctypes.c_int, ctx_map_dev_ids)

         # prepare param names
-        if param_names is None:
-            param_names = []
-        else:
+        param_name_list = []
+        if param_names is not None:
             if not isinstance(param_names, list):
                 raise ValueError('param_names in simple_bind must be a list or None')
+            param_name_list = [c_str(name) for name in param_names]

         # prepare shared_data_arrays
         if shared_data_arrays is None:
@@ -1250,27 +1271,11 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
             shared_data_array_names = []
             shared_data_array_handles = []
             for k, v in shared_data_arrays.items():
-                shared_data_array_names.append(k)
+                shared_data_array_names.append(c_str(k))
                 shared_data_array_handles.append(v.handle)
-            shared_data_array_names = ctypes.POINTER(ctypes.c_char_p)(shared_data_array_names)
+            shared_data_array_names = c_array(ctypes.c_char_p, shared_data_array_names)
             num_shared_data_arrays = mx_uint(len(shared_data_array_handles))
-
-        # prepare shared_exec_in_args
-        if shared_exec is None:
-            num_shared_exec_in_args = 0
-            shared_exec_in_arg_handles = ctypes.POINTER(NDArrayHandle)()
-            num_shared_exec_arg_grads = 0
-            shared_exec_arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
-            num_shared_exec_aux_states = 0
-            shared_exec_aux_state_handles = ctypes.POINTER(NDArrayHandle)()
-        else:
-            shared_exec_in_arg_handles = [nd.handle for nd in shared_exec.arg_arrays]
-            num_shared_exec_in_args = len(shared_exec_in_arg_handles)
-            shared_exec_arg_grad_handles = [nd.handle if nd is not None
-                                            else None for nd in shared_exec.grad_arrays]
-            num_shared_exec_arg_grads = len(shared_exec_arg_grad_handles)
-            shared_exec_aux_state_handles = [nd.handle for nd in shared_exec.aux_arrays]
-            num_shared_exec_aux_states = len(shared_exec_aux_state_handles)
+            shared_data_array_handles = c_array(NDArrayHandle, shared_data_array_handles)

         # prepare shared_exec_handle
         shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()
@@ -1279,46 +1284,38 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
         exe_handle = ExecutorHandle()

         # prepare current executor's in_args, arg_grads, and aux_states
+        num_in_args = ctypes.c_uint()
         in_arg_handles = ctypes.POINTER(NDArrayHandle)()
         arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
+        num_aux_states = ctypes.c_uint()
         aux_state_handles = ctypes.POINTER(NDArrayHandle)()
+
         check_call(_LIB.MXExecutorSimpleBind(self.handle,
                                              ctypes.c_int(ctx.device_typeid),
                                              ctypes.c_int(ctx.device_id),
-                                             mx_uint(len(ctx_map_keys)),
-                                             c_array(ctypes.c_char_p, ctx_map_keys),
-                                             c_array(ctypes.c_int, ctx_map_dev_types),
-                                             c_array(ctypes.c_int, ctx_map_dev_ids),
-                                             mx_uint(len(listed_arguments)),
-                                             c_array(ctypes.c_int,
                                                      [in_arg_ctx.device_typeid for in_arg_ctx in arg_ctx]),
-                                             c_array(ctypes.c_int, [in_arg_ctx.device_id for in_arg_ctx in arg_ctx]),
-                                             c_array(mx_uint, grad_req_types),
-                                             mx_uint(len(listed_aux_states)),
-                                             c_array(ctypes.c_int,
                                                      [aux_state_ctx.device_typeid for aux_state_ctx in aux_ctx]),
-                                             c_array(ctypes.c_int,
                                                      [aux_state_ctx.device_id for aux_state_ctx in aux_ctx]),
+                                             num_ctx_map_keys,
+                                             ctx_map_keys,
+                                             ctx_map_dev_types,
+                                             ctx_map_dev_ids,
+                                             mx_uint(provided_req_type_list_len),
+                                             provided_grad_req_names,
+                                             provided_grad_req_types,
                                              mx_uint(len(provided_arg_shape_names)),
                                              c_array(ctypes.c_char_p, provided_arg_shape_names),
                                              c_array(mx_uint, provided_arg_shape_data),
                                              c_array(mx_uint, provided_arg_shape_idx),
-                                             len(provided_arg_type_names),
-                                             c_array(ctypes.c_char_p, provided_arg_type_names),
-                                             c_array(ctypes.c_int, provided_arg_type_data),
-                                             mx_uint(len(param_names)),
-                                             c_array(ctypes.c_char_p, param_names),
+                                             num_provided_arg_types,
+                                             provided_arg_type_names,
+                                             provided_arg_type_data,
+                                             mx_uint(len(param_name_list)),
+                                             c_array(ctypes.c_char_p, param_name_list),
                                              ctypes.byref(num_shared_data_arrays),
                                              ctypes.byref(shared_data_array_names),
                                              ctypes.byref(shared_data_array_handles),
-                                             mx_uint(num_shared_exec_in_args),
-                                             shared_exec_in_arg_handles,
-                                             mx_uint(num_shared_exec_arg_grads),
-                                             shared_exec_arg_grad_handles,
-                                             mx_uint(num_shared_exec_aux_states),
-                                             shared_exec_aux_state_handles,
+                                             ctypes.byref(num_in_args),
                                              ctypes.byref(in_arg_handles),
                                              ctypes.byref(arg_grad_handles),
+                                             ctypes.byref(num_aux_states),
                                              ctypes.byref(aux_state_handles),
                                              shared_exec_handle,
                                              ctypes.byref(exe_handle)))

@@ -1333,11 +1330,11 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
                 shared_data_arrays[k] = v

         # create in_args, arg_grads, and aux_states for the current executor
-        arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(len(listed_arguments))]
+        arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)]
         grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i]))
                        if arg_grad_handles[i] is not None
-                       else None for i in range(len(listed_arguments))]
-        aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) for i in range(len(listed_aux_states))]
+                       else None for i in range(num_in_args.value)]
+        aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) for i in range(num_aux_states.value)]

         executor = Executor(exe_handle, self, ctx, grad_req, group2ctx)
         executor.arg_arrays = arg_arrays
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 7f46d06f3b2a..c6075b1d144e 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -167,11 +167,11 @@ int MXExecutorBindEX(SymbolHandle symbol_handle,
- * \param in_arg_len number of list_arguments
- * \param in_arg_dev_types device type list of list_arguments
- * \param in_arg_dev_ids device id list of list_arguments
- * \param grad_req_types req type list of all gradients of list_arguments
+ * \param provided_grad_req_list_len length of provided_grad_req_names/types;
+ *        0 with non-null types means a single grad req string for all arguments
+ * \param provided_grad_req_names argument names the grad req types apply to (NULL unless a dict was provided)
+ * \param provided_grad_req_types provided grad req types ('null', 'write', 'add')
- * \param aux_state_len number of list_auxiliary_states
- * \param aux_state_dev_types device type list of list_auxiliary_states
- * \param aux_state_dev_ids device id list of list_auxiliary_states
- * \param num_provided_args number of user provided in_arg and aux_state shapes
+ * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes
 * \param provided_arg_shape_names name list of provided shapes
 * \param provided_arg_shape_data provided shape data
 * \param provided_arg_shape_idx provided shape data index
@@ -202,14 +202,10 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                          const char** g2c_keys,
                          const int* g2c_dev_types,
                          const int* g2c_dev_ids,
-                         const mx_uint in_arg_len,
-                         const int* in_arg_dev_types,
-                         const int* in_arg_dev_ids,
-                         const mx_uint* grad_req_types,
-                         const mx_uint aux_state_len,
-                         const int* aux_state_dev_types,
-                         const int* aux_state_dev_ids,
-                         const mx_uint num_provided_args,
+                         const mx_uint provided_grad_req_list_len,
+                         const char** provided_grad_req_names,
+                         const char** provided_grad_req_types,
+                         const mx_uint num_provided_arg_shapes,
                          const char** provided_arg_shape_names,
                          const mx_uint* provided_arg_shape_data,
                          const mx_uint* provided_arg_shape_idx,
@@ -221,69 +217,153 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                          mx_uint* num_shared_data_arrays,
                          const char*** shared_data_array_name_list,
                          NDArrayHandle** shared_data_array_handle_list,
-                         const mx_uint num_shared_exec_in_args,
-                         NDArrayHandle* shared_exec_in_arg_handles,
-                         const mx_uint num_shared_exec_arg_grads,
-                         NDArrayHandle* shared_exec_arg_grad_handles,
-                         const mx_uint num_shared_exec_aux_states,
-                         NDArrayHandle* shared_exec_aux_state_handles,
+                         mx_uint* num_in_args,
                          NDArrayHandle** in_args,
                          NDArrayHandle** arg_grads,
+                         mx_uint* num_aux_states,
                          NDArrayHandle** aux_states,
                          ExecutorHandle shared_exec_handle,
                          ExecutorHandle* out) {
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);
+
+  // get in_arg names
+  std::vector<std::string> in_arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  std::vector<std::string> aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+
+  // attr_dict for setting up type_dict and arg/aux ctx
+  std::unordered_map<std::string, std::unordered_map<std::string, std::string>> attr_dict;
+  if (nullptr == provided_arg_dtypes || nullptr == g2c_keys) {
+    std::vector<std::tuple<std::string, std::string, std::string>> attrs =
+        sym->ListAttrsRecursive();
+    for (const auto& tp : attrs) {
+      attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp);
+    }
+  }
+
+  // setup arg_dtype_map
+  std::unordered_map<std::string, int> arg_dtype_map;
+  if (nullptr == provided_arg_dtypes) {  // use attr_dict
+    for (const auto& arg_name : in_arg_names) {
+      const auto it = attr_dict.find(arg_name);
+      if (it == attr_dict.end() || !it->second.count("__dtype__")) {
+        arg_dtype_map[arg_name] = mshadow::kFloat32;
+      }
+    }
+  } else {  // use user input type_dict
+    // create dtype map for in_args and aux_states
    for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) {
+      arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i];
+    }
+  }
+
   // create default ctx
   Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
-  // create ctx map
   std::map<std::string, Context> ctx_map;
-  for (mx_uint i = 0; i < num_g2c_keys; ++i) {
-    ctx_map[g2c_keys[i]] = Context::Create(
-        static_cast<Context::DeviceType>(g2c_dev_types[i]), g2c_dev_ids[i]);
-  }
+  std::vector<Context> in_arg_ctx_vec(in_arg_names.size(), ctx);
+  std::vector<Context> aux_state_ctx_vec(aux_state_names.size(), ctx);
+  if (nullptr != g2c_keys) {  // use user input group2ctx dict
+    for (mx_uint i = 0; i < num_g2c_keys; ++i) {
+      ctx_map[g2c_keys[i]] = Context::Create(
+          static_cast<Context::DeviceType>(g2c_dev_types[i]), g2c_dev_ids[i]);
+    }
+
+    // initialize in_arg_ctx_vec using group2ctx if there are any
+    for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) {
+      const auto it1 = attr_dict.find(in_arg_names[i]);
+      if (it1 != attr_dict.end()) {
+        const auto it2 = it1->second.find("__ctx_group__");
+        if (it2 != it1->second.end()) {
+          const auto it3 = ctx_map.find(it2->second);
+          if (it3 != ctx_map.end()) {
+            in_arg_ctx_vec[i] = it3->second;
+          }
+        }
+      }
+    }
+
+    // initialize aux_state_ctx_vec using group2ctx if there are any
+    for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) {
+      const auto it1 = attr_dict.find(aux_state_names[i]);
+      if (it1 != attr_dict.end()) {
+        const auto it2 = it1->second.find("__ctx_group__");
+        if (it2 != it1->second.end()) {
+          const auto it3 = ctx_map.find(it2->second);
+          if (it3 != ctx_map.end()) {
+            aux_state_ctx_vec[i] = it3->second;
+          }
+        }
+      }
+    }
+  }
-  // create ctxes for in_args and arg_grads
-  // create grad_req_type_vec for in_arg_grads
-  std::vector<Context> in_arg_ctx_vec;
-  std::vector<Context> arg_grad_ctx_vec;
-  std::vector<OpReqType> grad_req_type_vec;
-  for (mx_uint i = 0; i < in_arg_len; ++i) {
-    in_arg_ctx_vec.push_back(Context::Create(
-        static_cast<Context::DeviceType>(in_arg_dev_types[i]), in_arg_dev_ids[i]));
-    if (grad_req_types[i] == 0) {
-      arg_grad_ctx_vec.push_back(Context());
-      grad_req_type_vec.push_back(kNullOp);
-    } else {
-      arg_grad_ctx_vec.push_back(Context::Create(
-          static_cast<Context::DeviceType>(in_arg_dev_types[i]), in_arg_dev_ids[i]));
-      grad_req_type_vec.push_back(static_cast<OpReqType>(grad_req_types[i]));
-    }
-  }
+  // create provided_grad_req_map
+  const std::map<std::string, OpReqType> req_map =
+      {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}};
+  std::unordered_map<std::string, std::string> provided_grad_req_map;
+  std::string grad_req_type;
+  if (0 == provided_grad_req_list_len
+      && nullptr == provided_grad_req_names
+      && nullptr != provided_grad_req_types) {  // string, grad_req='write'
+    CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U)
+      << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; "
+         "only \'null\', \'write\', and \'add\' are supported";
+    grad_req_type = "string";
+  } else if (provided_grad_req_list_len > 0
+             && nullptr == provided_grad_req_names
+             && nullptr != provided_grad_req_types) {  // list, grad_req=['null', 'write']
+    grad_req_type = "list";
+    CHECK_EQ(provided_grad_req_list_len, in_arg_names.size())
+      << "The length of the grad_req list does not match the number of input arguments in "
+         "simple_bind; expected " << in_arg_names.size()
+      << ", provided " << provided_grad_req_list_len;
+  } else if (provided_grad_req_list_len > 0
+             && nullptr != provided_grad_req_names
+             && nullptr != provided_grad_req_types) {  // dict, grad_req={'lhs': 'null', 'rhs': 'write'}
+    grad_req_type = "dict";
+    for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) {
+      CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U)
+        << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; "
+           "only \'null\', \'write\', and \'add\' are supported";
+      provided_grad_req_map[provided_grad_req_names[i]] = provided_grad_req_types[i];
+    }
+  } else {  // grad_req is None
+    grad_req_type = "none";
+  }

-  // create ctxes for aux_states
-  std::vector<Context> aux_state_ctx_vec;
-  for (mx_uint i = 0; i < aux_state_len; ++i) {
-    aux_state_ctx_vec.push_back(Context::Create(
-        static_cast<Context::DeviceType>(aux_state_dev_types[i]), aux_state_dev_ids[i]));
-  }
+  // initialize arg_grad_ctx_vec and grad_req_type_vec
+  std::vector<Context> arg_grad_ctx_vec(in_arg_names.size());
+  std::vector<OpReqType> grad_req_type_vec(in_arg_names.size(), kNullOp);
+  if ("none" != grad_req_type) {
+    for (size_t i = 0; i < in_arg_names.size(); ++i) {
+      OpReqType cur_req = kNullOp;
+      if ("string" == grad_req_type) {
+        cur_req = req_map.at(provided_grad_req_types[0]);
+      } else if ("list" == grad_req_type) {
+        CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U)
+          << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; "
+             "only \'null\', \'write\', and \'add\' are supported";
+        cur_req = req_map.at(provided_grad_req_types[i]);
+      } else if ("dict" == grad_req_type) {
+        const auto it = provided_grad_req_map.find(in_arg_names[i]);
+        if (it != provided_grad_req_map.end()) {
+          cur_req = req_map.at(it->second);
+        }
+      }
+      if (kNullOp != cur_req) {
+        arg_grad_ctx_vec[i] = in_arg_ctx_vec[i];
+        grad_req_type_vec[i] = cur_req;
+      }
+    }
+  }

   // create shape map for in_args and aux_states
   std::unordered_map<std::string, TShape> arg_shape_map;
-  for (mx_uint i = 0; i < num_provided_args; ++i) {
+  for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) {
     arg_shape_map[provided_arg_shape_names[i]] =
         TShape(provided_arg_shape_data+provided_arg_shape_idx[i],
                provided_arg_shape_data+provided_arg_shape_idx[i+1]);
   }

-  // create dtype map for in_args and aux_states
-  std::unordered_map<std::string, int> arg_dtype_map;
-  for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) {
-    arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i];
-  }
-
   // create param name set for sharing data array memory
   std::unordered_set<std::string> param_name_set;
   for (mx_uint i = 0; i < num_param_names; ++i) {
@@ -303,31 +383,6 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
     for (mx_uint i = 0; i < *num_shared_data_arrays; ++i) {
       shared_data_array_map[*shared_data_array_name_list[i]] = *(*shared_data_array_ptrs)[i];
     }
-
-    // create shared_exec_in_args
-    NDArray** shared_exec_in_arg_ptrs =
-        reinterpret_cast<NDArray**>(shared_exec_in_arg_handles);
-    for (mx_uint i = 0; i < num_shared_exec_in_args; ++i) {
-      shared_exec_in_args.push_back(*shared_exec_in_arg_ptrs[i]);
-    }
-
-    // create shared_exec_arg_grads
-    NDArray** shared_exec_arg_grad_ptrs =
-        reinterpret_cast<NDArray**>(shared_exec_arg_grad_handles);
-    for (mx_uint i = 0; i < num_shared_exec_arg_grads; ++i) {
-      if (nullptr == shared_exec_arg_grad_ptrs[i]) {
-        shared_exec_arg_grads.push_back(NDArray());
-      } else {
-        shared_exec_arg_grads.push_back(*shared_exec_arg_grad_ptrs[i]);
-      }
-    }
-
-    // create shared_exec_aux_states
-    NDArray** shared_exec_aux_state_ptrs =
-        reinterpret_cast<NDArray**>(shared_exec_aux_state_handles);
-    for (mx_uint i = 0; i < num_shared_exec_aux_states; ++i) {
-      shared_exec_aux_states.push_back(*shared_exec_aux_state_ptrs[i]);
-    }
   }

   // create temporary place holders for the initialized NDArrays
@@ -338,8 +393,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
   *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
                               aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec,
-                              param_name_set, shared_exec_in_args, shared_exec_arg_grads,
-                              shared_exec_aux_states, &in_arg_vec, &arg_grad_vec, &aux_state_vec,
+                              param_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec,
                               use_shared_data_arrays ? &shared_data_array_map : nullptr,
                               reinterpret_cast<Executor*>(shared_exec_handle));

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index dac8e13e12eb..eb200ef52aa9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -78,6 +78,18 @@ const std::vector<NDArray>& GraphExecutor::outputs() const {
   return output_arrays_;
 }

+const std::unordered_map<std::string, NDArray>& GraphExecutor::in_arg_map() const {
+  return in_arg_map_;
+}
+
+const std::unordered_map<std::string, NDArray>& GraphExecutor::arg_grad_map() const {
+  return arg_grad_map_;
+}
+
+const std::unordered_map<std::string, NDArray>& GraphExecutor::aux_state_map() const {
+  return aux_state_map_;
+}
+
 nnvm::NodeEntry AttrHint(nnvm::NodeEntry src, nnvm::NodeEntry like) {
   static const Op* id_like = Op::Get("_identity_with_attr_like_rhs");
   nnvm::NodePtr n = nnvm::Node::Create();
@@ -385,19 +397,23 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
   nnvm::DTypeVector arg_dtypes;
   for (size_t i = 0; i < num_forward_inputs_; ++i) {
     const uint32_t nid = idx.input_nodes().at(i);
+    const std::string& arg_name = idx[nid].source->attrs.name;
     if (mutable_nodes.count(nid)) {
       CHECK_LT(aux_top, aux_states.size());
       data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top];
       arg_shapes.push_back(aux_states[aux_top].shape());
       arg_dtypes.push_back(aux_states[aux_top].dtype());
+      aux_state_map_.emplace(arg_name, aux_states[aux_top]);
       ++aux_top;
     } else {
       CHECK_LT(arg_top, in_args.size());
       data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top];
       arg_shapes.push_back(in_args[arg_top].shape());
       arg_dtypes.push_back(in_args[arg_top].dtype());
+      in_arg_map_.emplace(arg_name, in_args[arg_top]);
       if (kNullOp != grad_req_types[arg_top]) {
         grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]);
+        arg_grad_map_.emplace(arg_name, arg_grad_store[arg_top]);
       }
       ++arg_top;
     }
@@ -437,15 +453,18 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
   data_entry_.resize(idx.num_node_entries());
   size_t arg_top = 0, aux_top = 0;
   auto mutable_nodes = idx.mutable_input_nodes();
   for (size_t i = 0; i < num_forward_inputs_; ++i) {
     const uint32_t nid = idx.input_nodes().at(i);
     const uint32_t eid = idx.entry_id(nid, 0);
     const TShape& inferred_shape = inferred_shapes[eid];
     const int inferred_dtype = inferred_dtypes[eid];
+    const std::string& arg_name = idx[nid].source->attrs.name;
     if (mutable_nodes.count(nid)) {  // aux_states
       aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], false, inferred_dtype);
       aux_state_vec->back() = 0;
       data_entry_[eid] = aux_state_vec->back();
+      aux_state_map_.emplace(arg_name, aux_state_vec->back());
       ++aux_top;
     } else {  // in_args
       in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype);
@@ -457,7 +476,9 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
         arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], false, inferred_dtype);
         arg_grad_vec->back() = 0;
         grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
       }
+      in_arg_map_.emplace(arg_name, in_arg_vec->back());
       ++arg_top;
     }
   }
@@ -510,52 +531,49 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
                                   const std::vector<Context>& aux_state_ctxes,
                                   const std::vector<OpReqType>& grad_req_types,
                                   const std::unordered_set<std::string>& param_names,
-                                  const bool has_shared_exec,
-                                  const std::vector<NDArray>& shared_exec_in_args,
-                                  const std::vector<NDArray>& shared_exec_arg_grads,
-                                  const std::vector<NDArray>& shared_exec_aux_states,
+                                  const Executor* shared_exec,
                                   std::unordered_map<std::string, NDArray>* shared_data_arrays,
                                   std::vector<NDArray>* in_arg_vec,
                                   std::vector<NDArray>* arg_grad_vec,
                                   std::vector<NDArray>* aux_state_vec) {
-  if (has_shared_exec) {
-    CHECK_EQ(in_arg_ctxes.size(), shared_exec_in_args.size());
-    CHECK_EQ(arg_grad_ctxes.size(), shared_exec_arg_grads.size());
-    CHECK_EQ(aux_state_ctxes.size(), shared_exec_aux_states.size());
-  }
   // initialize in_args, arg_grads, and aux_states and populate grad_store_
   data_entry_.resize(idx.num_node_entries());
   size_t arg_top = 0, aux_top = 0;
   auto mutable_nodes = idx.mutable_input_nodes();
   for (size_t i = 0; i < num_forward_inputs_; ++i) {
     const uint32_t nid = idx.input_nodes().at(i);
     const uint32_t eid = idx.entry_id(nid, 0);
     const TShape& inferred_shape = inferred_shapes[eid];
     const int inferred_dtype = inferred_dtypes[eid];
+    const std::string& arg_name = idx[nid].source->attrs.name;
     if (mutable_nodes.count(nid)) {  // aux_states
-      if (has_shared_exec) {
-        const NDArray& aux_nd = shared_exec_aux_states[aux_top];
+      if (nullptr != shared_exec) {
+        // shared_exec is dereferenced only inside this branch, so a null
+        // shared_exec is safe; arrays are matched by argument name
+        const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
         CHECK_EQ(inferred_shape, aux_nd.shape())
           << "Inferred shape does not match shared_exec.aux_array's shape";
         CHECK_EQ(inferred_dtype, aux_nd.dtype())
           << "Inferred dtype does not match shared_exec.aux_array's dtype";
-        aux_state_vec->push_back(aux_nd);
+        aux_state_vec->emplace_back(aux_nd);
       } else {
         aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top],
                                     false, inferred_dtype);
         aux_state_vec->back() = 0;
-      }  // if (has_shared_exec)
+      }  // if (nullptr != shared_exec)
       data_entry_[eid] = aux_state_vec->back();
+      aux_state_map_.emplace(arg_name, aux_state_vec->back());
       ++aux_top;
     } else {  // in_args
-      const std::string& arg_name = idx[nid].source->attrs.name;
       if (param_names.count(arg_name)) {  // model parameter
-        if (has_shared_exec) {
-          const NDArray& in_arg_nd = shared_exec_in_args[arg_top];
+        if (nullptr != shared_exec) {
+          const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
           CHECK_EQ(inferred_shape, in_arg_nd.shape())
             << "Inferred shape does not match shared_exec.arg_array's shape";
           CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
             << "Inferred dtype does not match shared_exec.arg_array's dtype";
-          in_arg_vec->push_back(in_arg_nd);
+          in_arg_vec->emplace_back(in_arg_nd);
           if (kNullOp == grad_req_types[arg_top]) {
             arg_grad_vec->emplace_back();
           } else {
-            arg_grad_vec->push_back(shared_exec_arg_grads[arg_top]);
+            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
             grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
           }  // if (kNullOp == grad_req_types[arg_top])
         } else {  // shared_exec == nullptr
@@ -581,6 +599,10 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
           grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
         }  // if (kNullOp == grad_req_types[arg_top])
       }  // if (param_names.count(arg_name))
+      in_arg_map_.emplace(arg_name, in_arg_vec->back());
+      if (!arg_grad_vec->back().is_none()) {
+        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
+      }
       data_entry_[eid] = in_arg_vec->back();
       ++arg_top;
     }
@@ -672,9 +694,6 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
                          const std::unordered_map<std::string, int>& arg_dtype_map,
                          const std::vector<OpReqType>& grad_req_types,
                          const std::unordered_set<std::string>& param_names,
-                         const std::vector<NDArray>& shared_exec_in_args,
-                         const std::vector<NDArray>& shared_exec_arg_grads,
-                         const std::vector<NDArray>& shared_exec_aux_states,
                          std::vector<NDArray>* in_arg_vec,
                          std::vector<NDArray>* arg_grad_vec,
                          std::vector<NDArray>* aux_state_vec,
@@ -719,8 +738,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
     InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
                   g.GetAttr<nnvm::DTypeVector>("dtype"),
                   in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
-                  grad_req_types, param_names, shared_exec != nullptr,
-                  shared_exec_in_args, shared_exec_arg_grads, shared_exec_aux_states,
+                  grad_req_types, param_names, shared_exec,
+                  // shared_exec_in_args, shared_exec_arg_grads, shared_exec_aux_states,
                   shared_data_arrays, in_arg_vec, arg_grad_vec, aux_state_vec);
   }
   // The above code of shape and dtype inferences and argument
@@ -1242,9 +1261,6 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                const std::unordered_map<std::string, int>& arg_dtype_map,
                                const std::vector<OpReqType>& grad_req_types,
                                const std::unordered_set<std::string>& param_names,
-                               const std::vector<NDArray>& shared_exec_in_args,
-                               const std::vector<NDArray>& shared_exec_arg_grads,
-                               const std::vector<NDArray>& shared_exec_aux_states,
                                std::vector<NDArray>* in_args,
                                std::vector<NDArray>* arg_grads,
                                std::vector<NDArray>* aux_states,
@@ -1254,8 +1270,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
   exec->Init(symbol, default_ctx, group2ctx,
              in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
              arg_shape_map, arg_dtype_map,
-             grad_req_types, param_names, shared_exec_in_args,
-             shared_exec_arg_grads, shared_exec_aux_states,
+             grad_req_types, param_names,
              in_args, arg_grads, aux_states,
              shared_data_arrays, shared_exec);
   return exec;
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 9c6898924d09..3cdee2e4153e 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -49,6 +49,9 @@ class GraphExecutor : public Executor {
   void PartialForward(bool is_train, int step, int *step_left) override;
   void Backward(const std::vector<NDArray> &head_grads) override;
   const std::vector<NDArray>& outputs() const override;
+  const std::unordered_map<std::string, NDArray>& in_arg_map() const override;
+  const std::unordered_map<std::string, NDArray>& arg_grad_map() const override;
+  const std::unordered_map<std::string, NDArray>& aux_state_map() const override;
   void Print(std::ostream &os) const override;  // NOLINT(*)
   void SetMonitorCallback(const MonitorCallback& callback) override;
   // Initialize the rest of attributes
@@ -80,9 +83,6 @@ class GraphExecutor : public Executor {
             const std::unordered_map<std::string, int>& arg_dtype_map,
             const std::vector<OpReqType>& grad_req_types,
             const std::unordered_set<std::string>& param_names,
-            const std::vector<NDArray>& shared_exec_in_args,
-            const std::vector<NDArray>& shared_exec_arg_grads,
-            const std::vector<NDArray>& shared_exec_aux_states,
            std::vector<NDArray>* in_arg_vec,
            std::vector<NDArray>* arg_grad_vec,
            std::vector<NDArray>* aux_state_vec,
@@ -143,10 +143,7 @@ class GraphExecutor : public Executor {
                      const std::vector<Context>& aux_state_ctxes,
                      const std::vector<OpReqType>& grad_req_types,
                      const std::unordered_set<std::string>& param_names,  // DataParallelExecutorGroup.param_names
-                     const bool has_shared_exec,  // shared_exec != nullptr
-                     const std::vector<NDArray>& shared_exec_in_args,
-                     const std::vector<NDArray>& shared_exec_arg_grads,
-                     const std::vector<NDArray>& shared_exec_aux_states,
+                     const Executor* shared_exec,  // shared_exec != nullptr
                      std::unordered_map<std::string, NDArray>* shared_data_arrays,  // self.shared_data_arrays[i] L636
                      std::vector<NDArray>* in_arg_vec,
                      std::vector<NDArray>* arg_grad_vec,
@@ -193,6 +190,12 @@ class GraphExecutor : public Executor {
   std::vector<NDArray> data_pool_;
   // output arrays
   std::vector<NDArray> output_arrays_;
+  // input argument map, key is arg name, value is arg's NDArray
+  std::unordered_map<std::string, NDArray> in_arg_map_;
+  // arg grad map, key is arg name, value is arg grad NDArray
+  std::unordered_map<std::string, NDArray> arg_grad_map_;
+  // aux state map, key is aux state name, value is aux state NDArray
+  std::unordered_map<std::string, NDArray> aux_state_map_;
   // gradient store
   std::vector<std::pair<OpReqType, NDArray> > grad_store_;
   // array to hold head gradient.
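How the backend distinguishes the three grad_req forms is worth spelling out: a single string arrives as provided_grad_req_list_len == 0 with a one-element types array; a list arrives with list_len == len(grad_req) and null names; a dict arrives as parallel names/types arrays, and any argument not named falls back to kNullOp. A sketch of the three call forms (the symbol and shapes are illustrative, not from the patch):

    import mxnet as mx

    net = mx.sym.Variable('lhs') + mx.sym.Variable('rhs')

    # string: same req for every argument (list_len=0, names=NULL)
    exe = net.simple_bind(ctx=mx.cpu(), grad_req='write', lhs=(2, 2), rhs=(2, 2))

    # list: one req per argument, in list_arguments() order (names=NULL)
    exe = net.simple_bind(ctx=mx.cpu(), grad_req=['write', 'null'], lhs=(2, 2), rhs=(2, 2))

    # dict: named reqs; 'rhs' is absent, so the backend leaves it at kNullOp
    exe = net.simple_bind(ctx=mx.cpu(), grad_req={'lhs': 'add'}, lhs=(2, 2), rhs=(2, 2))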
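The shape kwargs travel in a compressed sparse form, as the comment in simple_bind above notes. This standalone sketch mirrors that loop (the dict contents are illustrative, and Python 3.7+ insertion-ordered dicts are assumed):

    # shape of the i-th provided argument == sdata[indptr[i]:indptr[i+1]]
    kwargs = {'lhs': (2, 2), 'rhs': (2, 2)}
    names, sdata, indptr = [], [], [0]
    for k, v in kwargs.items():
        if isinstance(v, tuple):
            names.append(k)
            sdata.extend(v)
            indptr.append(len(sdata))
    assert names == ['lhs', 'rhs']
    assert sdata == [2, 2, 2, 2]
    assert indptr == [0, 2, 4]
    # These become provided_arg_shape_names/_data/_idx in MXExecutorSimpleBind,
    # which rebuilds each TShape from the half-open range [idx[i], idx[i+1]).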