From f06d25ca244c96a2949a53599c13427caa1a2d57 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 4 Apr 2017 21:13:50 -0700 Subject: [PATCH 01/14] Initial checkin Add init functions for simple bind in graph_executor Add simple_bind c_api Add simple bind c-api Assign zeros to in_args, arg_grads, and aux_states Add simple_bind2 python interface Fix python interface bugs Interface changes Fix Fix core dump Add bind_ith_exec c_api Change simple_bind2 Fix seg fault Finish simple_bind Change _bind_ith_exec Refactor simple_bind initialization flow for bind Consolidate bind and simple_bind graph init flow Fix bug Clean up Add comments Clean up Clean up Minor correction Rename APIs in graph executor Refactor Rebase Delete deprecated functions Move more front-end work to backend Bug fix Fix failed tests Minor fix Fix lint Fix lint Revert unnecessary changes Revert Revert Clean up Fix lint Fix bind_ith_exec calling simple_bind Fix bugs for _bind_ith_exec --- include/mxnet/c_api.h | 32 ++ include/mxnet/executor.h | 32 ++ python/mxnet/module/executor_group.py | 81 +--- python/mxnet/symbol.py | 245 ++++++++--- src/c_api/c_api_executor.cc | 306 +++++++++++++ src/c_api/c_api_symbolic.cc | 1 - src/executor/graph_executor.cc | 573 ++++++++++++++++++++----- src/executor/graph_executor.h | 82 +++- tests/python/unittest/test_executor.py | 2 +- 9 files changed, 1111 insertions(+), 243 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d2efdf585e88..90270f776456 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1149,6 +1149,38 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, NDArrayHandle *aux_states, ExecutorHandle shared_exec, ExecutorHandle *out); + +MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** 
provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); /*! * \brief set a call back to notify the completion of operation */ diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index cf71666826ab..40bd60f5f405 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -69,6 +69,21 @@ class Executor { * \return array of outputs in the executor. */ virtual const std::vector &outputs() const = 0; + /*! + * \brief get input argument map, key is arg name, value is arg's NDArray. + * \return input argument map in the executor. + */ + virtual const std::unordered_map& in_arg_map() const = 0; + /*! + * \brief get input argument graident map, key is arg name, value is gradient's NDArray. + * \return input argument gradient map in the executor. + */ + virtual const std::unordered_map& arg_grad_map() const = 0; + /*! + * \brief get aux state map, key is arg name, value is aux state's NDArray. + * \return aux state map in the executor. + */ + virtual const std::unordered_map& aux_state_map() const = 0; /*! * \brief Create an operator by bind symbol with context and arguments. * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. 
@@ -91,6 +106,23 @@ class Executor { const std::vector &grad_req_type, const std::vector &aux_states, Executor* shared_exec = NULL); + + static Executor* SimpleBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::vector& grad_req_types, + const std::unordered_set& param_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* + shared_data_arrays = nullptr, + Executor* shared_exec = nullptr); /*! * \brief the prototype of user-defined monitor callback */ diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index 74640df97f16..35512bb7c60e 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -4,7 +4,6 @@ import logging from collections import OrderedDict - import numpy as np from .. import context as ctx @@ -564,6 +563,7 @@ def update_metric(self, eval_metric, labels): def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group): """Internal utility function to bind the i-th executor. + This function utilizes simple_bind python interface. 
""" shared_exec = None if shared_group is None else shared_group.execs[i] context = self.contexts[i] @@ -573,85 +573,14 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group): if label_shapes is not None: input_shapes.update(dict(label_shapes)) - arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes) - assert arg_shapes is not None, "shape inference failed" - input_types = {x.name: x.dtype for x in data_shapes} if label_shapes is not None: input_types.update({x.name: x.dtype for x in label_shapes}) - arg_types, _, aux_types = self.symbol.infer_type(**input_types) - assert arg_types is not None, "type inference failed" - - arg_arrays = [] - grad_arrays = {} if self.for_training else None - - def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger): - """Internal helper to get a memory block or re-use by re-shaping.""" - if name in shared_data_arrays: - arg_arr = shared_data_arrays[name] - if np.prod(arg_arr.shape) >= np.prod(arg_shape): - # nice, we can directly re-use this data blob - assert arg_arr.dtype == arg_type - arg_arr = arg_arr.reshape(arg_shape) - else: - logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) + - (', which is larger than already allocated ') + - ('shape %s' % (arg_arr.shape,)) + - ('. Need to re-allocate. 
Consider putting ') + - ('default_bucket_key to') + - (' be the bucket taking the largest input for better ') + - ('memory sharing.')) - arg_arr = nd.zeros(arg_shape, context, dtype=arg_type) - - # replace existing shared array because the new one is bigger - shared_data_arrays[name] = arg_arr - else: - arg_arr = nd.zeros(arg_shape, context, dtype=arg_type) - shared_data_arrays[name] = arg_arr - - return arg_arr - - # create or borrow arguments and gradients - for j in range(len(self.arg_names)): - name = self.arg_names[j] - if name in self.param_names: # model parameters - if shared_exec is None: - arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j]) - if self.grad_req[name] != 'null': - grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j]) - grad_arrays[name] = grad_arr - else: - arg_arr = shared_exec.arg_dict[name] - assert arg_arr.shape == arg_shapes[j] - assert arg_arr.dtype == arg_types[j] - if self.grad_req[name] != 'null': - grad_arrays[name] = shared_exec.grad_dict[name] - else: # data, label, or states - arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j], - context, self.logger) - - # data might also need grad if inputs_need_grad is True - if self.grad_req[name] != 'null': - grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays, - arg_shapes[j], arg_types[j], context, - self.logger) - - arg_arrays.append(arg_arr) - - # create or borrow aux variables - if shared_exec is None: - aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)] - else: - for j, arr in enumerate(shared_exec.aux_arrays): - assert aux_shapes[j] == arr.shape - assert aux_types[j] == arr.dtype - aux_arrays = shared_exec.aux_arrays[:] - - executor = self.symbol.bind(ctx=context, args=arg_arrays, - args_grad=grad_arrays, aux_states=aux_arrays, - grad_req=self.grad_req, shared_exec=shared_exec) - # Get the total bytes allocated for this executor + executor = self.symbol.simple_bind(ctx=context, 
grad_req=self.grad_req, + type_dict=input_types, shared_arg_names=self.param_names, + shared_exec=shared_exec, + shared_buffer=shared_data_arrays, **input_shapes) self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1]) return executor diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 16cbeae36531..c6973b133ab4 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -526,7 +526,7 @@ def list_attr(self, recursive=False): pairs = ctypes.POINTER(ctypes.c_char_p)() f_handle = _LIB.MXSymbolListAttrShallow check_call(f_handle(self.handle, ctypes.byref(size), ctypes.byref(pairs))) - return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)} + return {py_str(pairs[i * 2]): py_str(pairs[i * 2 + 1]) for i in range(size.value)} def attr_dict(self): """Recursively gets all attributes from the symbol and its children. @@ -552,8 +552,8 @@ def attr_dict(self): check_call(f_handle(self.handle, ctypes.byref(size), ctypes.byref(pairs))) ret = {} for i in range(size.value): - name, key = py_str(pairs[i*2]).split('$') - val = py_str(pairs[i*2+1]) + name, key = py_str(pairs[i * 2]).split('$') + val = py_str(pairs[i * 2 + 1]) if name not in ret: ret[name] = {} ret[name][key] = val @@ -776,7 +776,7 @@ def infer_type(self, *args, **kwargs): if s is not None: s = _numpy.dtype(s).type if s not in _DTYPE_NP_TO_MX: - raise TypeError('Argument need to be one of '+str(_DTYPE_NP_TO_MX)) + raise TypeError('Argument need to be one of ' + str(_DTYPE_NP_TO_MX)) sdata.append(_DTYPE_NP_TO_MX[s]) else: sdata.append(-1) @@ -885,7 +885,7 @@ def infer_shape(self, *args, **kwargs): if len(unknowns) >= 10: unknowns.append('...') break - unknowns.append('%s: %s'%(name, str(shape))) + unknowns.append('%s: %s' % (name, str(shape))) warnings.warn( "Cannot decide shape for the following arguments " + "(0s in shape means unknown dimensions). 
" + @@ -1012,7 +1012,7 @@ def _infer_shape_impl(self, partial, *args, **kwargs): return (arg_shapes, out_shapes, aux_shapes) else: return (None, None, None) - # pylint: enable=too-many-locals + # pylint: enable=too-many-locals def debug_str(self): """Gets a debug string of symbol. @@ -1160,12 +1160,10 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): raise TypeError('Only accept list of NDArrays or dict of str to NDArray') return c_array(NDArrayHandle, arg_handles), arg_arrays - def simple_bind(self, ctx, - grad_req='write', - type_dict=None, - group2ctx=None, - **kwargs): - """Binds current symbol to get an executor, allocate all the arguments needed. + def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, + shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs): + """Bind current symbol to get an executor, allocate all the arguments needed. + Allows specifying data types. This function simplifies the binding procedure. You need to specify only input data shapes. Before binding the executor, the function allocates arguments and auxiliary states @@ -1175,7 +1173,7 @@ def simple_bind(self, ctx, ---------- >>> x = mx.sym.Variable('x') >>> y = mx.sym.FullyConnected(x, num_hidden=4) - >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') >>> exe.forward() [] >>> exe.outputs[0].asnumpy() @@ -1208,6 +1206,19 @@ def simple_bind(self, ctx, group2ctx : Dict of string to mx.Context The dict mapping the `ctx_group` attribute to the context assignment. + shared_arg_names : List of string + The argument names whose `NDArray` of shared_exec can be reused for initializing + the current executor. + + shared_exec : Executor + The executor whose arg_arrays, arg_arrays, grad_arrays, and aux_arrays can be + reused for initializing the current executor. 
+ + shared_buffer : Dict of string to `NDArray` + The dict mapping argument names to the `NDArray` that can be reused for initializing + the current executor. This buffer will be checked for reuse if one argument name + of the current executor is not found in `shared_arg_names`. + kwargs : Dict of str->shape Input shape dictionary, name->shape @@ -1216,47 +1227,166 @@ def simple_bind(self, ctx, executor : mxnet.Executor The generated executor """ - # pylint: disable=too-many-locals - if type_dict is None: - attrs = self.attr_dict() - type_dict = {k: mx_real_t for k in self.list_arguments() - if k not in attrs or '__dtype__' not in attrs[k]} - arg_shapes, _, aux_shapes = self.infer_shape(**kwargs) - arg_types, _, aux_types = self.infer_type(**type_dict) - - if arg_shapes is None or arg_types is None: - raise ValueError("Input node is not complete") - + num_provided_arg_types = 0 + provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names + provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types + if type_dict is not None: + provided_arg_type_names = [] + provided_arg_type_data = [] + for k, v in type_dict.items(): + v = _numpy.dtype(v).type + if v in _DTYPE_NP_TO_MX: + provided_arg_type_names.append(c_str(k)) + provided_arg_type_data.append(ctypes.c_int(_DTYPE_NP_TO_MX[v])) + num_provided_arg_types = mx_uint(len(provided_arg_type_names)) + provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names) + provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data) + + provided_arg_shape_data = [] # shape data + # argument shape index in sdata, + # e.g. 
[sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg + provided_arg_shape_idx = [0] + provided_arg_shape_names = [] # provided argument names + for k, v in kwargs.items(): + # if k not in listed_arguments and k not in listed_aux_states: + # raise ValueError('arg name %s is not valid', k) + if isinstance(v, tuple): + provided_arg_shape_names.append(c_str(k)) + provided_arg_shape_data.extend(v) + provided_arg_shape_idx.append(len(provided_arg_shape_data)) + + provided_req_type_list_len = 0 + provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)() + provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)() + if grad_req is not None: + if isinstance(grad_req, string_types): + # use provided_req_type_list_len = 0 to indicate this situation + provided_req_type_list_len = 0 + provided_grad_req_types = [c_str(grad_req)] + elif isinstance(grad_req, list): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty list') + provided_grad_req_types = [c_str(item) for item in grad_req] + provided_req_type_list_len = len(provided_grad_req_types) + elif isinstance(grad_req, dict): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty dict') + provided_grad_req_names = [] + provided_grad_req_types = [] + for k, v in grad_req.items(): + provided_grad_req_names.append(c_str(k)) + provided_grad_req_types.append(c_str(v)) + provided_grad_req_names = c_array(ctypes.c_char_p, provided_grad_req_names) + provided_req_type_list_len = len(provided_grad_req_types) + provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types) + + num_ctx_map_keys = mx_uint(0) + ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)() + ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)() + ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)() if group2ctx is not None: - attr_dict = self.attr_dict() - arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) \ - if name in attr_dict and '__ctx_group__' in 
attr_dict[name] \ - else ctx for name in self.list_arguments()] - aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) \ - if name in attr_dict and '__ctx_group__' in attr_dict[name] \ - else ctx for name in self.list_auxiliary_states()] - else: - arg_ctx = [ctx] * len(arg_shapes) - aux_ctx = [ctx] * len(aux_shapes) - - # alloc space - arg_ndarrays = [ - _nd_zeros(shape, dev, dtype=dtype) - for dtype, dev, shape in zip(arg_types, arg_ctx, arg_shapes)] - if grad_req != 'null': - grad_ndarrays = {} - for name, shape, dev, dtype in zip( - self.list_arguments(), arg_shapes, arg_ctx, arg_types): - if not isinstance(grad_req, dict) or grad_req[name] != 'null': - grad_ndarrays[name] = _nd_zeros(shape, dev, dtype=dtype) + ctx_map_keys = [] + ctx_map_dev_types = [] + ctx_map_dev_ids = [] + for key, val in group2ctx.items(): + ctx_map_keys.append(c_str(key)) + ctx_map_dev_types.append(ctypes.c_int(val.device_typeid)) + ctx_map_dev_ids.append(ctypes.c_int(val.device_id)) + num_ctx_map_keys = mx_uint(len(ctx_map_keys)) + ctx_map_keys = c_array(ctypes.c_char_p, ctx_map_keys) + ctx_map_dev_types = c_array(ctypes.c_int, ctx_map_dev_types) + ctx_map_dev_ids = c_array(ctypes.c_int, ctx_map_dev_ids) + + # prepare param names + shared_arg_name_list = [] + if shared_arg_names is not None: + if not isinstance(shared_arg_names, list): + raise ValueError('shared_arg_names in simple_bind must be a list or None') + shared_arg_name_list = [c_str(name) for name in shared_arg_names] + + # prepare shared_buffer + if shared_buffer is None: + shared_buffer_len = ctypes.c_int(-1) + shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() + shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() else: - grad_ndarrays = None - - aux_ndarrays = [_nd_zeros(shape, dev, dtype=dtype) - for shape, dev, dtype in zip(aux_shapes, aux_ctx, aux_types)] - executor = self.bind(ctx, arg_ndarrays, - grad_ndarrays, grad_req, aux_ndarrays, - group2ctx=group2ctx) + if not isinstance(shared_buffer, dict): + 
raise ValueError('shared_buffer in simple_bind must be dict or None') + shared_buffer_names = [] + shared_buffer_handles = [] + for k, v in shared_buffer.items(): + shared_buffer_names.append(c_str(k)) + shared_buffer_handles.append(v.handle) + shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names) + shared_buffer_len = ctypes.c_int(len(shared_buffer_handles)) + shared_buffer_handles = c_array(NDArrayHandle, shared_buffer_handles) + updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() + updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() + + # prepare shared_exec_handle + shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle() + + # prepare current executor handle + exe_handle = ExecutorHandle() + + # prepare current executor's in_args, arg_grads, and aux_states + num_in_args = ctypes.c_uint() + in_arg_handles = ctypes.POINTER(NDArrayHandle)() + arg_grad_handles = ctypes.POINTER(NDArrayHandle)() + num_aux_states = ctypes.c_uint() + aux_state_handles = ctypes.POINTER(NDArrayHandle)() + + check_call(_LIB.MXExecutorSimpleBind(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_array(ctypes.c_char_p, provided_arg_shape_names), + c_array(mx_uint, provided_arg_shape_data), + c_array(mx_uint, provided_arg_shape_idx), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + mx_uint(len(shared_arg_name_list)), + c_array(ctypes.c_char_p, shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + 
ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + + # update shared_buffer + if shared_buffer is not None: + for i in range(shared_buffer_len.value): + k = py_str(updated_shared_buffer_names[i]) + v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i])) + shared_buffer[k] = v + + # create in_args, arg_grads, and aux_states for the current executor + arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)] + grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i])) + if arg_grad_handles[i] is not None + else None for i in range(num_in_args.value)] + aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) + for i in range(num_aux_states.value)] + + executor = Executor(exe_handle, self, ctx, grad_req, group2ctx) + executor.arg_arrays = arg_arrays + executor.grad_arrays = grad_arrays + executor.aux_arrays = aux_arrays return executor def bind(self, ctx, args, args_grad=None, grad_req='write', @@ -1441,6 +1571,7 @@ def grad(self, wrt): c_wrt, ctypes.byref(handle))) return Symbol(handle) + # pylint: enable= no-member def eval(self, ctx=cpu(), **kwargs): @@ -1500,7 +1631,6 @@ def reshape(self, shape): """ return reshape(self, shape=shape) - def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs): """Creates a symbolic variable with specified name. @@ -1565,9 +1695,11 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini ret._set_attr(**attr) return ret + # for back compatibility Variable = var + def Group(symbols): """Creates a symbol that contains a collection of other symbols, grouped together. 
@@ -1657,6 +1789,13 @@ def load_json(json_str): return Symbol(handle) +<<<<<<< HEAD +======= +# Initialize the atomic symbol in startups +_init_symbol_module(Symbol, "mxnet") + + +>>>>>>> Initial checkin # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ce765acd77bf..8d40514bae49 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -154,6 +154,312 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, API_END_HANDLE_ERROR(delete exec); } +/*! + * \brief + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared 
data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorSimpleBind(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(symbol_handle); + + // get in_arg names + std::vector in_arg_names = 
sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); + std::vector aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates); + + // attr_dict for setting up type_dict and arg/aux ctx + std::unordered_map> attr_dict; + if (nullptr == provided_arg_dtypes || nullptr == g2c_keys) { + std::vector> attrs = + sym->ListAttrsRecursive(); + attr_dict.reserve(attrs.size()); + for (const auto& tp : attrs) { + attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp); + } + } + + // setup arg_dtype_map + std::unordered_map arg_dtype_map; + if (nullptr == provided_arg_dtypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__dtype__")) { + arg_dtype_map[arg_name] = mshadow::kFloat32; + } + } + } else { // use user input type_dict + // create dtype map for in_args and aux_states + arg_dtype_map.reserve(num_provided_arg_dtypes); + for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) { + arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i]; + } + } + + // create default ctx + Context ctx = Context::Create(static_cast(dev_type), dev_id); + // create ctx map + std::map ctx_map; + std::vector in_arg_ctx_vec(in_arg_names.size(), ctx); + std::vector aux_state_ctx_vec(aux_state_names.size(), ctx); + if (nullptr != g2c_keys) { // use user input group2ctx dict + for (mx_uint i = 0; i < num_g2c_keys; ++i) { + ctx_map[g2c_keys[i]] = Context::Create( + static_cast(g2c_dev_types[i]), g2c_dev_ids[i]); + } + + // initialize in_arg_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(in_arg_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + in_arg_ctx_vec[i] = it3->second; + } + } + } + } + + // initialize aux_state_ctx_vec 
using group2ctx if there are any + for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(aux_state_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + aux_state_ctx_vec[i] = it3->second; + } + } + } + } + } + + // create provided_grad_req_map + const std::map req_map = + {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}}; + std::unordered_map provided_grad_req_map; + std::string grad_req_type; + if (0 == provided_grad_req_list_len + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // string, grad_req='write' + CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U) + << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + grad_req_type = "string"; + } else if (provided_grad_req_list_len > 0 + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // list, grad_req=['null', 'write'] + grad_req_type = "list"; + CHECK_EQ(provided_grad_req_list_len, in_arg_names.size()) + << "The length of grad_req list does not match the number of input arguments in simple_bind, " + "expected " << in_arg_names.size() << ", provided " << provided_grad_req_list_len; + } else if (provided_grad_req_list_len > 0 + && nullptr != provided_grad_req_names + && nullptr != provided_grad_req_types) { // dict, grad_req=['lhs': 'null', 'rhs': 'write'] + grad_req_type = "dict"; + provided_grad_req_map.reserve(provided_grad_req_list_len); + for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + provided_grad_req_map[provided_grad_req_names[i]] = 
provided_grad_req_types[i]; + } + } else { // grad_req is None + grad_req_type = "none"; + } + + // initialize arg_grad_ctx_vec and grad_req_type_vec + std::vector arg_grad_ctx_vec(in_arg_names.size(), ctx); + std::vector grad_req_type_vec(in_arg_names.size(), kNullOp); + if ("none" != grad_req_type) { + for (size_t i = 0; i < in_arg_names.size(); ++i) { + OpReqType cur_req = kNullOp; + if ("string" == grad_req_type) { + cur_req = req_map.at(provided_grad_req_types[0]); + } else if ("list" == grad_req_type) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + cur_req = req_map.at(provided_grad_req_types[i]); + } else if ("dict" == grad_req_type) { + const auto it = provided_grad_req_map.find(in_arg_names[i]); + if (it != provided_grad_req_map.end()) { + cur_req = req_map.at(it->second); + } + } + if (kNullOp != cur_req) { + arg_grad_ctx_vec[i] = in_arg_ctx_vec[i]; + grad_req_type_vec[i] = static_cast(cur_req); + } + } + } + + // create shape map for in_args and aux_states + std::unordered_map arg_shape_map(num_provided_arg_shapes); + for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) { + auto p = arg_shape_map.emplace(provided_arg_shape_names[i], + TShape(provided_arg_shape_data+provided_arg_shape_idx[i], + provided_arg_shape_data+provided_arg_shape_idx[i+1])); + CHECK(p.second) << "Duplicate shapes are provided for argument " + << provided_arg_shape_names[i] << " in simple_bind"; + } + + // create para name set for sharing data array memory + std::unordered_set shared_arg_name_set(num_shared_arg_names); + for (mx_uint i = 0; i < num_shared_arg_names; ++i) { + shared_arg_name_set.insert(shared_arg_name_list[i]); + } + + // create shared_buffer_map + std::unordered_map shared_buffer_map; + std::vector shared_exec_in_args; + std::vector shared_exec_arg_grads; + std::vector shared_exec_aux_states; + bool 
use_shared_buffer = (*shared_buffer_len >= 0); + if (*shared_buffer_len > 0) { + // create shared_buffer_map + shared_buffer_map.reserve(*shared_buffer_len); + NDArray** shared_buffer_ptrs = + reinterpret_cast(shared_buffer_handle_list); + for (int i = 0; i < *shared_buffer_len; ++i) { + shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]); + } + } + + // create temporary place holders for the initialized NDArrays + // to be passed back to front end + std::vector in_arg_vec; + std::vector arg_grad_vec; + std::vector aux_state_vec; + + *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec, + shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec, + use_shared_buffer? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); + + // copy ndarray ptrs to ret->handles so that front end + // can access them + ret->ret_handles.clear(); + ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size() + +shared_buffer_map.size()); + size_t nd_idx = 0; + for (const auto& nd : in_arg_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Input argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (in_arg_vec.size() > 0) { + *num_in_args = in_arg_vec.size(); + *in_args = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : arg_grad_vec) { + if (nd.is_none()) { + ret->ret_handles.push_back(nullptr); + } else { + ret->ret_handles.push_back(new NDArray(nd)); + } + } + if (arg_grad_vec.size() > 0) { + *arg_grads = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : aux_state_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (aux_state_vec.size() > 0) { + *num_aux_states = aux_state_vec.size(); + 
*aux_states = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + if (use_shared_buffer) { + ret->ret_vec_charp.clear(); + ret->ret_vec_charp.reserve(shared_buffer_map.size()); + for (const auto kv : shared_buffer_map) { + if (kv.second.is_none()) { + LOG(FATAL) << "Shared data NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(kv.second)); + ret->ret_vec_charp.push_back(kv.first.c_str()); + } + *shared_buffer_len = shared_buffer_map.size(); + *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); + *updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]); + } + + API_END(); +} + int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle) { diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 27df5b2de1f3..cad9e604df60 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -379,7 +379,6 @@ int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { API_END(); } - namespace mxnet { template diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 6f8f820e02dc..64e2410bedb2 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -78,6 +78,18 @@ const std::vector& GraphExecutor::outputs() const { return output_arrays_; } +const std::unordered_map& GraphExecutor::in_arg_map() const { + return in_arg_map_; +} + +const std::unordered_map& GraphExecutor::arg_grad_map() const { + return arg_grad_map_; +} + +const std::unordered_map& GraphExecutor::aux_state_map() const { + return aux_state_map_; +} + nnvm::NodeEntry AttrHint(nnvm::NodeEntry src, nnvm::NodeEntry like) { static const Op* id_like = Op::Get("_identity_with_attr_like_rhs"); nnvm::NodePtr n = nnvm::Node::Create(); @@ -178,10 +190,12 @@ inline ValueType get_node_attr( } } -nnvm::Graph GraphExecutor::InitFullGraph( - nnvm::Symbol symbol, - const std::vector& grad_req_type, - const std::vector& 
arg_grad_store) { +/*! + * \brief Create the graph for backward pass. + * This is triggered by both simple_bind and bind flows. + */ +nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, + const std::vector& grad_req_types) { using nnvm::NodePtr; using nnvm::NodeEntry; // initial information @@ -191,7 +205,7 @@ nnvm::Graph GraphExecutor::InitFullGraph( nnvm::Graph g; g.outputs = symbol.outputs; bool need_grad = false; - for (OpReqType req : grad_req_type) { + for (OpReqType req : grad_req_types) { if (req != kNullOp) need_grad = true; } if (!need_grad) return g; @@ -202,10 +216,8 @@ nnvm::Graph GraphExecutor::InitFullGraph( } std::vector args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs); std::vector xs; - for (size_t i = 0; i < grad_req_type.size(); ++i) { - if (grad_req_type[i] != kNullOp) { - grad_store_.emplace_back( - std::make_pair(grad_req_type[i], arg_grad_store[i])); + for (size_t i = 0; i < grad_req_types.size(); ++i) { + if (grad_req_types[i] != kNullOp) { xs.emplace_back(NodeEntry{args[i], 0, 0}); } } @@ -242,13 +254,16 @@ nnvm::Graph GraphExecutor::InitFullGraph( return g; } -// pass to assign context to the graph +/*! + * \brief Assign context to the graph. + * This is triggered by both simple_bind and bind flows. 
+ */ Graph AssignContext(Graph g, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_args, - const std::vector >& grad_store, - const std::vector& aux_states, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, size_t num_forward_inputs, size_t num_forward_outputs) { const auto& idx = g.indexed_graph(); @@ -257,56 +272,65 @@ Graph AssignContext(Graph g, if (ctx_map.size() == 0) { g.attrs["context"] = std::make_shared( ContextVector(idx.num_nodes(), default_ctx)); - for (const auto& x : in_args) { - CHECK(x.ctx() == default_ctx) - << "Input array is in " << x.ctx() << " while binding with ctx=" << default_ctx + for (const auto& x : in_arg_ctxes) { + CHECK(x == default_ctx) + << "Input array is in " << x << " while binding with ctx=" << default_ctx << ". All arguments must be in global context (" << default_ctx << ") unless group2ctx is specified for cross-device graph."; } - for (const auto& x : grad_store) { - CHECK(x.second.ctx() == default_ctx) - << "Gradient array is in " << x.second.ctx() << " while binding with ctx=" + for (const auto& x : arg_grad_ctxes) { + CHECK(x == default_ctx) + << "Gradient array is in " << x << " while binding with ctx=" << default_ctx << ". All gradients must be in global context (" << default_ctx << ") unless group2ctx is specified for cross-device graph."; } return g; } + // otherwise, use context assignment. 
- std::map ctx2id; - std::vector ctx_list; - nnvm::DeviceVector device(idx.num_nodes(), -1); - nnvm::DeviceAssignMap device_map; + std::map ctx2id; // map ctx to device id + std::vector ctx_list; // index is device id + nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id + nnvm::DeviceAssignMap device_map; // map arg name to device id + // loop through the user input ctx_map and + // populate maps and lists for (auto &kv : ctx_map) { - if (ctx2id.count(kv.second) == 0) { - ctx2id[kv.second] = static_cast(ctx_list.size()); - ctx_list.push_back(kv.second); + if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one + ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx + ctx_list.push_back(kv.second); // save ctx to the list } + // assign device id to the arg name with the corresponding ctx device_map[kv.first] = ctx2id.at(kv.second); } + // loop through all the rest of input nodes not specified + // in the ctx_map and populate maps and lists size_t arg_top = 0, aux_top = 0; for (size_t i = 0; i < num_forward_inputs; ++i) { const uint32_t nid = idx.input_nodes().at(i); Context ctx; - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_states.size()); - ctx = aux_states[aux_top].ctx(); + if (mutable_nodes.count(nid)) { // aux node is mutable + CHECK_LT(aux_top, aux_state_ctxes.size()); + ctx = aux_state_ctxes[aux_top]; ++aux_top; - } else { - CHECK_LT(arg_top, in_args.size()); - ctx = in_args[arg_top].ctx(); + } else { // regular input node is immutable + CHECK_LT(arg_top, in_arg_ctxes.size()); + ctx = in_arg_ctxes[arg_top]; ++arg_top; } - if (ctx2id.count(ctx) == 0) { - ctx2id[ctx] = static_cast(ctx_list.size()); - ctx_list.push_back(ctx); + if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id + ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id + ctx_list.push_back(ctx); // save the current ctx in the list } - device[nid] = 
ctx2id.at(ctx); + device[nid] = ctx2id.at(ctx); // assign device id to the current node } + + // loop through backward input nodes and populate maps and lists + // the backward input nodes is the gradient of the loss wrt the output for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) { const uint32_t nid = idx.outputs()[i].node_id; - Context ctx = grad_store[i - num_forward_outputs].second.ctx(); + Context ctx = arg_grad_ctxes[i - num_forward_outputs]; if (ctx2id.count(ctx) == 0) { ctx2id[ctx] = static_cast(ctx_list.size()); ctx_list.push_back(ctx); @@ -318,6 +342,7 @@ Graph AssignContext(Graph g, device[nid] = devid; } } + g.attrs["device"] = std::make_shared(std::move(device)); g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); const auto& assigned_device = g.GetAttr("device"); @@ -334,27 +359,312 @@ Graph AssignContext(Graph g, return g; } +/*! + * \brief GraphExecutor initializer for regular bind flow in which + * input arguments and gradients are provided by users. This initializer + * uses the user provided NDArrays to populate data entries of the graph. 
+ */ void GraphExecutor::Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, const std::vector& in_args, const std::vector& arg_grad_store, - const std::vector& grad_req_type, + const std::vector& grad_req_types, const std::vector& aux_states, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { - nnvm::Graph g = InitGraph(symbol, default_ctx, - ctx_map, in_args, arg_grad_store, - grad_req_type, aux_states, feed_dict); + // create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes + auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); }; + auto get_ctx2 = [default_ctx](const NDArray& nd) -> Context { + if (nd.is_none()) return default_ctx; + return nd.ctx(); + }; + std::vector in_arg_ctxes(in_args.size()); + std::transform(in_args.begin(), in_args.end(), in_arg_ctxes.begin(), get_ctx1); + std::vector arg_grad_ctxes(arg_grad_store.size()); + std::transform(arg_grad_store.begin(), arg_grad_store.end(), arg_grad_ctxes.begin(), get_ctx2); + std::vector aux_state_ctxes(aux_states.size()); + std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1); + + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, + arg_grad_ctxes, aux_state_ctxes, grad_req_types); + + // create arg_shapes and arg_dtypes for shape and type inferences + const auto& idx = g.indexed_graph(); + auto mutable_nodes = idx.mutable_input_nodes(); + size_t arg_top = 0, aux_top = 0; + data_entry_.resize(idx.num_node_entries()); + nnvm::ShapeVector arg_shapes; + nnvm::DTypeVector arg_dtypes; + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string& arg_name = idx[nid].source->attrs.name; + if (mutable_nodes.count(nid)) { + CHECK_LT(aux_top, aux_states.size()); + data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; + arg_shapes.push_back(aux_states[aux_top].shape()); + arg_dtypes.push_back(aux_states[aux_top].dtype()); + aux_state_map_.emplace(arg_name, 
aux_states[aux_top]); + ++aux_top; + } else { + CHECK_LT(arg_top, in_args.size()); + data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; + arg_shapes.push_back(in_args[arg_top].shape()); + arg_dtypes.push_back(in_args[arg_top].dtype()); + in_arg_map_.emplace(arg_name, in_args[arg_top]); + if (kNullOp != grad_req_types[arg_top]) { + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]); + arg_grad_map_.emplace(arg_name, arg_grad_store[arg_top]); + } + ++arg_top; + } + } + + // expand arg_shapes and arg_dtypes to contain backward inputs + arg_shapes.resize(idx.input_nodes().size(), TShape()); + arg_dtypes.resize(idx.input_nodes().size(), -1); + // Infer shapes and dtypes + g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + + // Initialize the rest attributes of the graph. + // This function can be called by regular bind + // operation flow as well. + FinishInitGraph(symbol, g, shared_exec, feed_dict); +} + +/*! + * \brief Initialize in_args, arg_grads, and aux_states + * and their data_entry_ of the executor. This function + * is called for regular simple_bind flow, i.e. no + * shared data arrays are provided. 
+ */ +void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) { + // initialize in_args, arg_grads, and aux_states + // populate grad_store_ + data_entry_.resize(idx.num_node_entries()); + size_t arg_top = 0, aux_top = 0; + auto mutable_nodes = idx.mutable_input_nodes(); + // TODO(junwu): populate in_arg_map, arg_grad_map, and aux_state_map + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + const int inferred_dtype = inferred_dtypes[eid]; + const std::string& arg_name = idx[nid].source->attrs.name; + if (mutable_nodes.count(nid)) { // aux_states + aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], false, inferred_dtype); + aux_state_vec->back() = 0; + data_entry_[eid] = aux_state_vec->back(); + aux_state_map_.emplace(arg_name, aux_state_vec->back()); + ++aux_top; + } else { // in_args + in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); + in_arg_vec->back() = 0; + data_entry_[eid] = in_arg_vec->back(); + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], false, inferred_dtype); + arg_grad_vec->back() = 0; + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); + } + in_arg_map_.emplace(arg_name, in_arg_vec->back()); + ++arg_top; + } + } +} + +/*! 
+ * \brief If the requested ndarray's shape size is less than + * the corresponding shared_data_array's shape size, reuse + * the memory allocation; otherwise, create a zero ndarray. + */ +NDArray ReshapeOrCreate(const std::string& name, + const TShape& dest_arg_shape, + const int dest_arg_dtype, + const Context& ctx, + std::unordered_map* shared_buffer) { + auto it = shared_buffer->find(name); + if (it != shared_buffer->end()) { + if (it->second.shape().Size() >= dest_arg_shape.Size()) { // memory can be reused + CHECK_EQ(it->second.dtype(), dest_arg_dtype) + << "Requested arg array's dtype does not match the reusable ndarray"; + return it->second.Reshape(dest_arg_shape); + } else { + LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape + << ", which is larger than already allocated shape " << it->second.shape() + << ". Need to re-allocate. Consider putting default bucket key to be " + << "the bucket taking the largest input for better memory sharing."; + it->second = NDArray(dest_arg_shape, ctx, false, dest_arg_dtype); + it->second = 0; + return it->second; + } // arg_array.shape().Size() >= arg_shape.Size() + } else { + auto p = shared_buffer->emplace(name, NDArray(dest_arg_shape, ctx, false, dest_arg_dtype)); + p.first->second = 0; + return p.first->second; + } // if (it != shared_buffer->end()) +} + +/*! + * \brief Initialize in_args, arg_grads, and aux_states + * and their data_entry_ of the executor using + * shared_buffer from DataParallelExecutorGroup + * and shared_exec if available. 
+ */ +void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) { + // initialize in_args, arg_grads, and aux_states and populate grad_store_ + data_entry_.resize(idx.num_node_entries()); + size_t arg_top = 0, aux_top = 0; + auto mutable_nodes = idx.mutable_input_nodes(); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + const int inferred_dtype = inferred_dtypes[eid]; + const std::string& arg_name = idx[nid].source->attrs.name; + if (mutable_nodes.count(nid)) { // aux_states + if (nullptr != shared_exec) { + const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); + CHECK_EQ(inferred_shape, aux_nd.shape()) + << "Inferred shape does not match shared_exec.aux_array's shape." + " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument" + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, aux_nd.dtype()) + << "Inferred dtype does not match shared_exec.aux_array's dtype." 
+ " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument" + << arg_name << " for the current executor"; + aux_state_vec->emplace_back(aux_nd); + } else { + aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], + false, inferred_dtype); + aux_state_vec->back() = 0; + } // if (has_shared_exec) + data_entry_[eid] = aux_state_vec->back(); + aux_state_map_.emplace(arg_name, aux_state_vec->back()); + ++aux_top; + } else { // in_args + if (shared_arg_names.count(arg_name)) { // model parameter + if (nullptr != shared_exec) { + const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); + CHECK_EQ(inferred_shape, in_arg_nd.shape()) + << "Inferred shape does not match shared_exec.arg_array's shape" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument" + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, in_arg_nd.dtype()) + << "Inferred dtype does not match shared_exec.arg_array's dtype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument" + << arg_name << " for the current executor"; + in_arg_vec->emplace_back(in_arg_nd); + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } // if (kNullOp == grad_req_types[arg_top]) + } else { // !has shared_exec + in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); + in_arg_vec->back() = 0; + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], + false, inferred_dtype); + arg_grad_vec->back() = 0; + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); 
+ } // if (kNullOp == grad_req_types[arg_top]) + } // if (has_shared_exec) + } else { // !shared_arg_names.count(arg_name) + in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, + in_arg_ctxes[arg_top], shared_buffer)); + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, + inferred_dtype, arg_grad_ctxes[arg_top], + shared_buffer)); + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } // if (kNullOp == grad_req_types[arg_top]) + } // if (shared_arg_names.count(arg_name)) + in_arg_map_.emplace(arg_name, in_arg_vec->back()); + if (!arg_grad_vec->back().is_none()) { + arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); + } + data_entry_[eid] = in_arg_vec->back(); + ++arg_top; + } + } +} + +/*! + * \brief Finish graph initialization after shape and dtype inferences. + * This function is used by both simple_bind and bind flows. 
+ */ +void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, + nnvm::Graph g, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { + const auto& idx = g.indexed_graph(); + for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { + data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; + } + + { + // memory allocator + const int kBadStorageID = -1; + const int kExternalStorageID = -2; + nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); + for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { + arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; + } + for (const auto& kv : feed_dict) { + uint32_t eid = idx.entry_id(kv.first); + data_entry_[eid] = kv.second; + arg_storage_id[eid] = kExternalStorageID; + } + g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); + g = nnvm::ApplyPass(g, "PlanMemory"); + } + g = DetectInplaceAddTo(g); + g.attrs["saved_opr"] = std::make_shared(std::move(saved_opr_)); g = AttachOpExecs(g); g = AttachOpResources(g); graph_ = std::move(g); + if (shared_exec != nullptr) { this->InitDataEntryMemory(&(dynamic_cast(shared_exec)->data_pool_)); } else { this->InitDataEntryMemory(nullptr); } + { // initialize output arrays auto& idx = graph_.indexed_graph(); @@ -374,22 +684,111 @@ void GraphExecutor::Init(nnvm::Symbol symbol, this->InitOpSegs(); } +/*! + * \brief GraphExecutor initializer for simple bind flow in + * which only certain input shapes and dtypes are provided by users. + * The initializer uses these shapes and dtypes to perform + * shape and dtype inferences, and then create NDArrays + * to populate data entries of the graph. The created NDArrays + * for in_args, arg_grads and aux_states are passed to the + * front end to attach the created executor. 
+ * In front end, if the simple_bind flow is triggered by + * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup + * and shared executor will be taken into account in creating + * NDArrays for in_args, arg_grads, and aux_states for reusing + * already allocated memory. + */ +void GraphExecutor::Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, + aux_state_ctxes, grad_req_types); + // The following code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize arg_shapes and arg_dtypes for shape and type inferences. + // It contains all in_args and aux_states' shapes and types in a certain order. 
+ const nnvm::IndexedGraph& idx = g.indexed_graph(); + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string& name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map.find(name); + if (arg_shape_map.end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map.find(name); + if (arg_dtype_map.end() != it2) { + arg_dtypes[i] = it2->second; + } + } + g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + + // Create in_args, arg_grads, and aux_states using + // the inferred shapes and dtypes. + if (nullptr == shared_buffer) { // regular simple bind + InitArguments(idx, g.GetAttr("shape"), + g.GetAttr("dtype"), + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); + } else { // simple bind using shared data arrays and shared_exec + InitArguments(idx, g.GetAttr("shape"), + g.GetAttr("dtype"), + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, shared_arg_names, shared_exec, + shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); + } + // The above code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize the rest attributes of the graph. + // This function can be called by regular bind + // operation flow as well. + FinishInitGraph(symbol, g, shared_exec, feed_dict); +} + +/*! + * \brief This function is triggered by both simple_bind + * and bind flows. + * Setup backward graph, create device and context + * attributes in the graph, and calculate the number + * of forward nodes. 
+ */ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_args, - const std::vector& arg_grad_store, - const std::vector& grad_req_type, - const std::vector& aux_states, - const nnvm::NodeEntryMap& feed_dict) { + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types) { // setup gradient - nnvm::Graph g = InitFullGraph(symbol, grad_req_type, arg_grad_store); + nnvm::Graph g = InitFullGraph(symbol, grad_req_types); + + // create "device" and "context" attrs for the graph g = AssignContext(g, default_ctx, ctx_map, - in_args, - grad_store_, - aux_states, + in_arg_ctxes, + arg_grad_ctxes, + aux_state_ctxes, num_forward_inputs_, num_forward_outputs_); + const auto& idx = g.indexed_graph(); // get number of nodes used in forward pass num_forward_nodes_ = 0; @@ -397,55 +796,6 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, num_forward_nodes_ = std::max( num_forward_nodes_, static_cast(idx.outputs()[i].node_id + 1)); } - // Setup data entry, shape and type. 
- data_entry_.resize(idx.num_node_entries()); - auto mutable_nodes = idx.mutable_input_nodes(); - nnvm::ShapeVector arg_shapes; - nnvm::DTypeVector arg_types; - size_t arg_top = 0, aux_top = 0; - for (size_t i = 0; i < num_forward_inputs_; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_states.size()); - data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; - arg_shapes.push_back(aux_states[aux_top].shape()); - arg_types.push_back(aux_states[aux_top].dtype()); - ++aux_top; - } else { - CHECK_LT(arg_top, in_args.size()); - data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; - arg_shapes.push_back(in_args[arg_top].shape()); - arg_types.push_back(in_args[arg_top].dtype()); - ++arg_top; - } - } - for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { - data_entry_[idx.entry_id(idx.outputs()[j])] - = grad_store_[j - num_forward_outputs_].second; - } - arg_shapes.resize(idx.input_nodes().size(), TShape()); - arg_types.resize(idx.input_nodes().size(), -1); - // other initializations - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); - g = nnvm::pass::InferType(g, arg_types, "__dtype__"); - - { - // memory allocator - const int kBadStorageID = -1; - const int kExternalStorageID = -2; - nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); - for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { - arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; - } - for (const auto& kv : feed_dict) { - uint32_t eid = idx.entry_id(kv.first); - data_entry_[eid] = kv.second; - arg_storage_id[eid] = kExternalStorageID; - } - g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); - g = nnvm::ApplyPass(g, "PlanMemory"); - } - g = DetectInplaceAddTo(g); return g; } @@ -913,6 +1263,31 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, } } // namespace exec +Executor *Executor::SimpleBind(nnvm::Symbol 
symbol, + const Context& default_ctx, + const std::map& group2ctx, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* shared_buffer, + Executor* shared_exec) { + auto exec = new exec::GraphExecutor(); + exec->Init(symbol, default_ctx, group2ctx, + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + arg_shape_map, arg_dtype_map, + grad_req_types, shared_arg_names, + in_args, arg_grads, aux_states, + shared_buffer, shared_exec); + return exec; +} + Executor *Executor::Bind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index d9c3a3e6aa47..d5a4e8c3aa6c 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -49,19 +49,47 @@ class GraphExecutor : public Executor { void PartialForward(bool is_train, int step, int *step_left) override; void Backward(const std::vector &head_grads) override; const std::vector& outputs() const override; + const std::unordered_map& in_arg_map() const override; + const std::unordered_map& arg_grad_map() const override; + const std::unordered_map& aux_state_map() const override; void Print(std::ostream &os) const override; // NOLINT(*) void SetMonitorCallback(const MonitorCallback& callback) override; - // initialized the executor + // Initialize the rest of attributes + // after setting up arguments. 
+ void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); + + // initialize executor for bind void Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, const std::vector& in_args, const std::vector& arg_grad_store, - const std::vector& grad_req_type, + const std::vector& grad_req_types, const std::vector& aux_states, Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict = nnvm::NodeEntryMap()); + // initialize executor for simple bind + void Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer = nullptr, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); protected: // Information about operational node @@ -94,21 +122,43 @@ class GraphExecutor : public Executor { // list of op executors std::vector exec_list; }; - - // internal initialization of the graph. 
+ // Initialize in_args, arg_grads, and aux_states + void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); + // Initialize in_args, arg_grads and aux_states with + // shared_buffer and shared_exec + void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); + // internal initialization of the graph for simple bind Graph InitGraph(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_args, - const std::vector& arg_grad_store, - const std::vector& grad_req_type, - const std::vector& aux_states, - const nnvm::NodeEntryMap& feed_dict - = nnvm::NodeEntryMap()); - // initialize the full graph, including gradient. 
+ const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types); + // initialize the full graph for simple bind Graph InitFullGraph(nnvm::Symbol symbol, - const std::vector& grad_req_type, - const std::vector& arg_grad_store); + const std::vector& grad_req_types); // initialize the cached operator void InitCachedOps(); // initialize the opr segments for bulk exec @@ -140,6 +190,12 @@ class GraphExecutor : public Executor { std::vector data_pool_; // output arrays std::vector output_arrays_; + // input argument map, key is arg name, value is arg's NDArray + std::unordered_map in_arg_map_; + // arg grad map, key is arg name, value is arg grad NDArray + std::unordered_map arg_grad_map_; + // aux state map, key is aux state name, value is aux state NDArray + std::unordered_map aux_state_map_; // gradient store std::vector > grad_store_; // array to hold head gradient. diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index b190b2898843..c1cc013b81c0 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -121,7 +121,7 @@ def test_reshape(): x = mx.sym.Variable('x') y = mx.sym.FullyConnected(x, num_hidden=4) - exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') exe.arg_arrays[0][:] = 1 exe.arg_arrays[1][:] = mx.nd.ones((4,4)) exe.arg_arrays[2][:] = 0 From 2bb952a9e3298ecaa41298de175f350dbbdc2389 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Tue, 23 May 2017 14:16:01 -0700 Subject: [PATCH 02/14] Add unit test (#1) * Add unit test * Fix * Small fix --- python/mxnet/test_utils.py | 23 ++++++++ tests/python/unittest/test_module.py | 84 ++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 6089edae5a56..e839c4b722be 100644 ---
a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1020,3 +1020,26 @@ def set_env_var(key, val, default_val=""): prev_val = os.environ.get(key, default_val) os.environ[key] = val return prev_val + +def same_array(array1, array2): + """Check whether two NDArrays share the same memory block + + Parameters + ---------- + + array1 : NDArray + First NDArray to be checked + array2 : NDArray + Second NDArray to be checked + + Returns + ------- + bool + Whether two NDArrays share the same memory + """ + array1[:] += 1 + if not same(array1.asnumpy(), array2.asnumpy()): + array1[:] -= 1 + return False + array1[:] -= 1 + return same(array1.asnumpy(), array2.asnumpy()) \ No newline at end of file diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 5508a37c9567..4dde5a60b8e3 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -2,6 +2,7 @@ import mxnet.ndarray as nd import numpy as np from functools import reduce +from mxnet.module.executor_group import DataParallelExecutorGroup def test_module_dtype(): dtype = np.float16 dshape = (3, 8, 7) @@ -254,6 +255,88 @@ def mean_abs(x): break assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) +def test_executor_group(): + def test_create_exec_group(exec_grp_shared, exec_grp_created, + shared_arg_names, extra_input=[], extra_arg=[]): + # Test shared data arrays + for i in range(len(exec_grp_shared.execs)): + for data_name, array in exec_grp_shared.shared_data_arrays[i].items(): + assert data_name in exec_grp_created.shared_data_arrays[i], \ + "Shared input data '%s' is not in " \ + "shared_data_arrays of created executor group." % (data_name) + assert mx.test_utils.same_array(array, exec_grp_created.shared_data_arrays[i][data_name]), \ + "Shared input data '%s' does not share memory."
% (data_name) + for input_name in extra_input: + assert input_name in exec_grp_created.execs[i].arg_dict, \ + "Extra input data '%s' is not in arg_dict of created executor group." % (input_name) + + # Test shared argument arrays and gradient arrays + for i in range(len(exec_grp_shared.execs)): + exec1 = exec_grp_shared.execs[i] + exec2 = exec_grp_created.execs[i] + for arg_name in shared_arg_names: + assert arg_name in exec2.arg_dict, \ + "Shared argument '%s' is not in arg_dict of created executor group." % (arg_name) + assert mx.test_utils.same_array(exec1.arg_dict[arg_name], exec2.arg_dict[arg_name]), \ + "Shared argument '%s' does not share memory." % (arg_name) + for arg_name in extra_arg: + assert arg_name in exec2.arg_dict, \ + "Extra argument '%s' is not in arg_dict of created executor group." % (arg_name) + for arg_name, grad in exec_grp_shared.grad_req.items(): + assert grad == exec_grp_created.grad_req[arg_name], \ + "Gradient requirements for shared argument '%s' are inconsistent. " \ + "Shared executor group requires '%s' while created executor group requires '%s'" \ + %(arg_name, grad, exec_grp_created.grad_req[arg_name]) + for arg_name in shared_arg_names: + assert arg_name in exec2.grad_dict, \ + "Shared argument gradient '%s' is not in " \ + "grad_dict of created executor group." % (arg_name) + assert mx.test_utils.same_array(exec1.grad_dict[arg_name], exec2.grad_dict[arg_name]), \ + "Shared argument gradient '%s' does not sharing memory." 
% (arg_name) + + contexts = [mx.cpu(0), mx.cpu(1)] + workload = [1] * len(contexts) + batch_size = 16 + num_hidden = 4 + data_shapes1 = [('data1', (batch_size, 10))] + data_shapes2 = [('data1', (batch_size, 10)), ('data2', (batch_size, 10))] + label_shapes = [('softmax_label', (batch_size,))] + + data1 = mx.sym.Variable('data1') + data2 = mx.sym.Variable('data2') + fc1 = mx.sym.FullyConnected(data=data1, name='fc1', num_hidden=num_hidden) + mlp1 = mx.sym.SoftmaxOutput(data=fc1, name='softmax') + fc1 = mx.sym.FullyConnected(data=data1 + data2, name='fc1', num_hidden=num_hidden) + fc2 = mx.sym.FullyConnected(data=fc1, name='fc2', num_hidden=num_hidden) + mlp2 = mx.sym.SoftmaxOutput(data=fc2, name='softmax') + + arg_names = mlp1.list_arguments() + input_names = [name[0] for name in data_shapes1] + [name[0] for name in label_shapes] + shared_arg_names = [name for name in arg_names if name not in input_names] + + exec_group1 = DataParallelExecutorGroup(symbol=mlp1, contexts=contexts, + workload=workload, data_shapes=data_shapes1, + label_shapes=label_shapes, param_names=shared_arg_names, + for_training=True, inputs_need_grad=False) + + # Test two executor groups with the same symbol sharing memory + exec_group2 = DataParallelExecutorGroup(symbol=mlp1, contexts=contexts, + workload=workload, data_shapes=data_shapes1, + label_shapes=label_shapes, param_names=shared_arg_names, + for_training=True, inputs_need_grad=False, + shared_group=exec_group1) + test_create_exec_group(exec_group1, exec_group2, shared_arg_names) + + # Test two executor groups with different symbol sharing memory + exec_group3 = DataParallelExecutorGroup(symbol=mlp2, contexts=contexts, + workload=workload, data_shapes=data_shapes2, + label_shapes=label_shapes, param_names=shared_arg_names, + for_training=True, inputs_need_grad=False, + shared_group=exec_group1) + extra_input = ['data2'] + extra_arg = ['fc2_weight', 'fc2_bias'] + test_create_exec_group(exec_group1, exec_group3, shared_arg_names, 
extra_input, extra_arg) + if __name__ == '__main__': test_module_dtype() test_module_input_grads() @@ -263,3 +346,4 @@ def mean_abs(x): test_module_layout() test_module_switch_bucket() test_monitor() + test_executor_group() From d266d1859566cbfcbca09951d755cc1806f47e8c Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 23 May 2017 14:40:14 -0700 Subject: [PATCH 03/14] Fix lint --- python/mxnet/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index e839c4b722be..7b9a402e6dba 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1042,4 +1042,5 @@ def same_array(array1, array2): array1[:] -= 1 return False array1[:] -= 1 - return same(array1.asnumpy(), array2.asnumpy()) \ No newline at end of file + return same(array1.asnumpy(), array2.asnumpy()) + From 95cd217a419dfc915e6c9ac4dbf107ab7fda92ff Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 23 May 2017 14:48:45 -0700 Subject: [PATCH 04/14] Fix lint --- python/mxnet/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 7b9a402e6dba..3ab44d0917a1 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1043,4 +1043,3 @@ def same_array(array1, array2): return False array1[:] -= 1 return same(array1.asnumpy(), array2.asnumpy()) - From 6fc1886bb8974e3425cdf6161f64cad6f209e90a Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 25 May 2017 20:31:09 -0700 Subject: [PATCH 05/14] Fix bugs of missing ndarrays in shared_buffer --- src/c_api/c_api_executor.cc | 7 +- tests/python/unittest/test_module.py | 150 +++++++++++++++++---------- 2 files changed, 99 insertions(+), 58 deletions(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 8d40514bae49..87085befef3d 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -443,14 +443,17 @@ int MXExecutorSimpleBind(SymbolHandle 
symbol_handle, } if (use_shared_buffer) { + ret->ret_vec_str.clear(); + ret->ret_vec_str.reserve(shared_buffer_map.size()); ret->ret_vec_charp.clear(); ret->ret_vec_charp.reserve(shared_buffer_map.size()); - for (const auto kv : shared_buffer_map) { + for (const auto& kv : shared_buffer_map) { if (kv.second.is_none()) { LOG(FATAL) << "Shared data NDArray cannot be un-allocated"; } ret->ret_handles.push_back(new NDArray(kv.second)); - ret->ret_vec_charp.push_back(kv.first.c_str()); + ret->ret_vec_str.emplace_back(kv.first); + ret->ret_vec_charp.push_back(ret->ret_vec_str.back().c_str()); } *shared_buffer_len = shared_buffer_map.size(); *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 4dde5a60b8e3..9f3cff8e1265 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -4,6 +4,7 @@ from functools import reduce from mxnet.module.executor_group import DataParallelExecutorGroup + def test_module_dtype(): dtype = np.float16 dshape = (3, 8, 7) @@ -46,6 +47,7 @@ def test_module_input_grads(): assert np.all(b_grad == 2), b_grad assert np.all(c_grad == 3), c_grad + def test_module_layout(): sym = mx.sym.Variable('data') sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC') @@ -63,6 +65,7 @@ def test_module_layout(): for x in mod.get_outputs(merge_multi_context=False)[0]: assert x.shape == hdshape + def test_save_load(): def dict_equ(a, b): assert set(a) == set(b) @@ -102,6 +105,7 @@ def dict_equ(a, b): dict_equ(mod.get_params()[0], mod2.get_params()[0]) dict_equ(mod._kvstore._updater.states, mod2._updater.states) + def test_module_reshape(): data = mx.sym.Variable('data') sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc') @@ -128,6 +132,7 @@ def test_module_reshape(): assert mod.get_outputs()[0].shape == dshape assert (mod.get_params()[0]['fc_bias'].asnumpy() == -3).all() + def test_module_states(): stack 
= mx.rnn.SequentialRNNCell() for i in range(2): @@ -154,6 +159,7 @@ def test_module_states(): for x1, x2 in zip(out1, out2): assert not mx.test_utils.almost_equal(x1.asnumpy(), x2.asnumpy(), rtol=1e-3) + def test_module_switch_bucket(): vocab_dim = 5000 num_hidden = 100 @@ -208,6 +214,7 @@ def create_bucketing_module(key): #the default bucket is expected to reuse the bytes allocated assert total_bytes_after == total_bytes_before + def test_monitor(): # data iter mx.random.seed(11) @@ -255,87 +262,118 @@ def mean_abs(x): break assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) + def test_executor_group(): - def test_create_exec_group(exec_grp_shared, exec_grp_created, - shared_arg_names, extra_input=[], extra_arg=[]): + def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len): + stack = mx.rnn.SequentialRNNCell() + for i in range(num_layers): + stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i)) + data = mx.sym.Variable('data') + label = mx.sym.Variable('softmax_label') + embed = mx.sym.Embedding(data=data, input_dim=num_words, + output_dim=num_embed, name='embed') + + stack.reset() + outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True) + + pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden)) + pred = mx.sym.FullyConnected(data=pred, num_hidden=num_words, name='pred') + + label = mx.sym.Reshape(label, shape=(-1,)) + pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax') + return pred + + def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=None, extra_args=None): # Test shared data arrays for i in range(len(exec_grp_shared.execs)): + # test same shared_data_arrays for two exec groups + shared_data_array1 = exec_grp_shared.shared_data_arrays[i] + shared_data_array2 = exec_grp_created.shared_data_arrays[i] + if extra_args is not None: + assert len(shared_data_array1) == len(extra_args),\ + "exec_grp_shared.shared_data_arrays[%d] should have same number of args as extra_args" 
+ assert len(shared_data_array1) == len(shared_data_array2),\ + "length of shared_data_array of the shared executor group not equal to the created executor group" + for k, v in shared_data_array1.items(): + if extra_args is not None: + assert k in extra_args, "arg %s is not in extra_args" % k + assert k in shared_data_array2,\ + "arg %s of the shared executor group not in the shared_data_array of the created executor group" % k + assert mx.test_utils.same_array(v, shared_data_array2[k]) + for data_name, array in exec_grp_shared.shared_data_arrays[i].items(): assert data_name in exec_grp_created.shared_data_arrays[i], \ "Shared input data '%s' is not in " \ "shared_data_arrays of created executor group." % (data_name) assert mx.test_utils.same_array(array, exec_grp_created.shared_data_arrays[i][data_name]), \ "Shared input data '%s' does not share memory." % (data_name) - for input_name in extra_input: - assert input_name in exec_grp_created.execs[i].arg_dict, \ - "Extra input data '%s' is not in arg_dict of created executor group." % (input_name) - # Test shared argument arrays and gradient arrays - for i in range(len(exec_grp_shared.execs)): - exec1 = exec_grp_shared.execs[i] - exec2 = exec_grp_created.execs[i] - for arg_name in shared_arg_names: - assert arg_name in exec2.arg_dict, \ - "Shared argument '%s' is not in arg_dict of created executor group." % (arg_name) - assert mx.test_utils.same_array(exec1.arg_dict[arg_name], exec2.arg_dict[arg_name]), \ - "Shared argument '%s' does not share memory." % (arg_name) - for arg_name in extra_arg: - assert arg_name in exec2.arg_dict, \ - "Extra argument '%s' is not in arg_dict of created executor group." 
% (arg_name) + # Test shared argument arrays and gradient arrays + exec_shared = exec_grp_shared.execs[i] + exec_created = exec_grp_created.execs[i] + if shared_arg_names is not None: + # test shared arguments + for arg_name in shared_arg_names: + assert arg_name in exec_created.arg_dict, \ + "Shared argument '%s' is not in arg_dict of created executor group." % (arg_name) + assert mx.test_utils.same_array(exec_shared.arg_dict[arg_name], exec_created.arg_dict[arg_name]), \ + "Shared argument '%s' does not share memory." % (arg_name) + # test shared argument gradients + for arg_name in shared_arg_names: + assert arg_name in exec_created.grad_dict, \ + "Shared argument gradient '%s' is not in " \ + "grad_dict of created executor group." % (arg_name) + assert mx.test_utils.same_array(exec_shared.grad_dict[arg_name], exec_created.grad_dict[arg_name]), \ + "Shared argument gradient '%s' does not sharing memory." % (arg_name) + for arg_name, grad in exec_grp_shared.grad_req.items(): assert grad == exec_grp_created.grad_req[arg_name], \ "Gradient requirements for shared argument '%s' are inconsistent. " \ "Shared executor group requires '%s' while created executor group requires '%s'" \ %(arg_name, grad, exec_grp_created.grad_req[arg_name]) - for arg_name in shared_arg_names: - assert arg_name in exec2.grad_dict, \ - "Shared argument gradient '%s' is not in " \ - "grad_dict of created executor group." % (arg_name) - assert mx.test_utils.same_array(exec1.grad_dict[arg_name], exec2.grad_dict[arg_name]), \ - "Shared argument gradient '%s' does not sharing memory." 
% (arg_name) contexts = [mx.cpu(0), mx.cpu(1)] workload = [1] * len(contexts) - batch_size = 16 - num_hidden = 4 - data_shapes1 = [('data1', (batch_size, 10))] - data_shapes2 = [('data1', (batch_size, 10)), ('data2', (batch_size, 10))] - label_shapes = [('softmax_label', (batch_size,))] - - data1 = mx.sym.Variable('data1') - data2 = mx.sym.Variable('data2') - fc1 = mx.sym.FullyConnected(data=data1, name='fc1', num_hidden=num_hidden) - mlp1 = mx.sym.SoftmaxOutput(data=fc1, name='softmax') - fc1 = mx.sym.FullyConnected(data=data1 + data2, name='fc1', num_hidden=num_hidden) - fc2 = mx.sym.FullyConnected(data=fc1, name='fc2', num_hidden=num_hidden) - mlp2 = mx.sym.SoftmaxOutput(data=fc2, name='softmax') - - arg_names = mlp1.list_arguments() - input_names = [name[0] for name in data_shapes1] + [name[0] for name in label_shapes] - shared_arg_names = [name for name in arg_names if name not in input_names] - - exec_group1 = DataParallelExecutorGroup(symbol=mlp1, contexts=contexts, - workload=workload, data_shapes=data_shapes1, + batch_size = 32 + max_bucket_size = 80 + num_words = 1000 + num_hidden = 100 + num_embed = 200 + data_shapes = [('data', (batch_size, max_bucket_size))] + label_shapes = [('softmax_label', (batch_size, max_bucket_size))] + + # generate an rnn sym with #layers=3 + sym = get_rnn_sym(num_layers=3, num_words=num_words, num_hidden=num_hidden, + num_embed=num_embed, seq_len=max_bucket_size) + arg_names1 = sym.list_arguments() + input_names = [name[0] for name in data_shapes] + [name[0] for name in label_shapes] + shared_arg_names = [name for name in arg_names1 if name not in input_names] + exec_group1 = DataParallelExecutorGroup(symbol=sym, contexts=contexts, + workload=workload, data_shapes=data_shapes, label_shapes=label_shapes, param_names=shared_arg_names, for_training=True, inputs_need_grad=False) - # Test two executor groups with the same symbol sharing memory - exec_group2 = DataParallelExecutorGroup(symbol=mlp1, contexts=contexts, -
workload=workload, data_shapes=data_shapes1, + # shared_data_arrays should only have input "data" and "softmax_label" arrays + for i in range(len(contexts)): + assert len(exec_group1.shared_data_arrays[i]) == len(input_names),\ + "exec_group1.shared_data_arrays[%d] should have the same number of names as in input_names" % i + for name in input_names: + assert name in exec_group1.shared_data_arrays[i],\ + "arg %s should be in exec_group1.shared_data_arrays[%d]" % (name, i) + + # generate an rnn sym with #layers=5 + sym = get_rnn_sym(num_layers=5, num_words=num_words, num_hidden=num_hidden, + num_embed=num_embed, seq_len=max_bucket_size) + arg_names2 = sym.list_arguments() + exec_group2 = DataParallelExecutorGroup(symbol=sym, contexts=contexts, + workload=workload, data_shapes=data_shapes, label_shapes=label_shapes, param_names=shared_arg_names, for_training=True, inputs_need_grad=False, shared_group=exec_group1) - test_create_exec_group(exec_group1, exec_group2, shared_arg_names) + extra_args = [name for name in arg_names2 if name not in shared_arg_names] + test_shared_exec_group(exec_grp_shared=exec_group1, exec_grp_created=exec_group2, + shared_arg_names=shared_arg_names, extra_args=extra_args) - # Test two executor groups with different symbol sharing memory - exec_group3 = DataParallelExecutorGroup(symbol=mlp2, contexts=contexts, - workload=workload, data_shapes=data_shapes2, - label_shapes=label_shapes, param_names=shared_arg_names, - for_training=True, inputs_need_grad=False, - shared_group=exec_group1) - extra_input = ['data2'] - extra_arg = ['fc2_weight', 'fc2_bias'] - test_create_exec_group(exec_group1, exec_group3, shared_arg_names, extra_input, extra_arg) if __name__ == '__main__': test_module_dtype() From d7e8366f40934f65c9bca089ef1b892dc30a783f Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 25 May 2017 20:44:03 -0700 Subject: [PATCH 06/14] Fix lint --- python/mxnet/module/executor_group.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index 35512bb7c60e..ce71fa3ad4e9 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -4,7 +4,6 @@ import logging from collections import OrderedDict -import numpy as np from .. import context as ctx from .. import ndarray as nd From aae92f6e122ec0287f1fd075629cec5d122a876c Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Tue, 30 May 2017 11:21:26 -0700 Subject: [PATCH 07/14] Simple bind (#3) * Add bucketing test * Skip pylint * Use cpu to train --- tests/python/train/test_bucketing.py | 107 +++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/python/train/test_bucketing.py diff --git a/tests/python/train/test_bucketing.py b/tests/python/train/test_bucketing.py new file mode 100644 index 000000000000..047870b6b79e --- /dev/null +++ b/tests/python/train/test_bucketing.py @@ -0,0 +1,107 @@ +# pylint: skip-file +import numpy as np +import mxnet as mx +import random +from random import randint + +def test_bucket_module(): + import logging + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.DEBUG, format=head) + + class DummySentenceIter(mx.rnn.BucketSentenceIter): + """Dummy sentence iterator to output sentences the same as input. 
+ """ + def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1, + data_name='data', label_name='l2_label', dtype='float32', + layout='NTC'): + super(DummySentenceIter, self).__init__(sentences, batch_size, + buckets=buckets, invalid_label=invalid_label, + data_name=data_name, label_name=label_name, + dtype=dtype, layout=layout) + + def reset(self): + """Resets the iterator to the beginning of the data.""" + self.curr_idx = 0 + random.shuffle(self.idx) + for buck in self.data: + np.random.shuffle(buck) + + self.nddata = [] + self.ndlabel = [] + for buck in self.data: + self.nddata.append(mx.nd.array(buck, dtype=self.dtype)) + self.ndlabel.append(mx.nd.array(buck, dtype=self.dtype)) + + batch_size = 128 + num_epochs = 20 + num_hidden = 50 + num_embed = 50 + num_layers = 2 + len_vocab = 100 + buckets = [10, 20, 30, 40, 50, 60] + + invalid_label = 0 + num_sentence = 2500 + + train_sent = [] + val_sent = [] + + for _ in range(num_sentence): + len_sentence = randint(1, max(buckets) + 10) + train_sentence = [] + val_sentence = [] + for _ in range(len_sentence): + train_sentence.append(randint(1, len_vocab)) + val_sentence.append(randint(1, len_vocab)) + train_sent.append(train_sentence) + val_sent.append(val_sentence) + + data_train = DummySentenceIter(train_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) + data_val = DummySentenceIter(val_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) + + stack = mx.rnn.SequentialRNNCell() + for i in range(num_layers): + stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i)) + + def sym_gen(seq_len): + data = mx.sym.Variable('data') + label = mx.sym.Variable('l2_label') + embed = mx.sym.Embedding(data=data, input_dim=len_vocab, + output_dim=num_embed, name='embed') + + stack.reset() + outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True) + + pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden)) + pred = mx.sym.FullyConnected(data=pred, 
num_hidden=1, name='pred') + pred = mx.sym.reshape(pred, shape= (batch_size, -1)) + loss = mx.sym.LinearRegressionOutput(pred, label, name='l2_loss') + + return loss, ('data',), ('l2_label',) + + contexts = mx.cpu(0) + + model = mx.mod.BucketingModule( + sym_gen = sym_gen, + default_bucket_key = data_train.default_bucket_key, + context = contexts) + + model.fit( + train_data = data_train, + eval_data = data_val, + eval_metric = mx.metric.MSE(), + kvstore = 'device', + optimizer = 'sgd', + optimizer_params = { 'learning_rate': 0.01, + 'momentum': 0, + 'wd': 0.00001 }, + initializer = mx.init.Xavier(factor_type="in", magnitude=2.34), + num_epoch = num_epochs, + batch_end_callback = mx.callback.Speedometer(batch_size, 50)) + assert model.score(data_val, mx.metric.MSE())[0][1] < 15, "High mean square error." + +if __name__ == "__main__": + test_bucket_module() From 9058ae78f90f7799729a05dfd3caa9ec90184b42 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 30 May 2017 13:49:53 -0700 Subject: [PATCH 08/14] Fix bug --- src/c_api/c_api_executor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 87085befef3d..ca49402ecf7e 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -229,7 +229,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, // attr_dict for setting up type_dict and arg/aux ctx std::unordered_map> attr_dict; - if (nullptr == provided_arg_dtypes || nullptr == g2c_keys) { + if (nullptr == provided_arg_dtypes || nullptr != g2c_keys) { std::vector> attrs = sym->ListAttrsRecursive(); attr_dict.reserve(attrs.size()); From 722c9e5887158d0ec0b15bf92bfbc9ef518bba98 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 30 May 2017 20:36:36 -0700 Subject: [PATCH 09/14] Remove merge message --- python/mxnet/symbol.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index c6973b133ab4..1aeb2d54dce3 100644 --- 
a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1789,13 +1789,6 @@ def load_json(json_str): return Symbol(handle) -<<<<<<< HEAD -======= -# Initialize the atomic symbol in startups -_init_symbol_module(Symbol, "mxnet") - - ->>>>>>> Initial checkin # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): From be02b285e32cd8bf619766b128aa577886a0a88f Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 30 May 2017 20:52:20 -0700 Subject: [PATCH 10/14] Fix lint --- python/mxnet/symbol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 1aeb2d54dce3..0a26afec4731 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -13,11 +13,11 @@ import numpy as _numpy from .base import _LIB, numeric_types -from .base import c_array, c_str, mx_uint, py_str, string_types, mx_real_t +from .base import c_array, c_str, mx_uint, py_str, string_types from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle -from .base import check_call, MXNetError, _Null # pylint: disable=unused-import +from .base import check_call, MXNetError, _Null # pylint: disable=unused-import from .context import Context, cpu -from .ndarray import NDArray, zeros as _nd_zeros, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .name import NameManager # pylint: disable=unused-import from .executor import Executor from . 
import _symbol_internal as _internal From 1236d7ed5f20dea0e874605679f55be8fd8d968a Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 31 May 2017 11:10:26 -0700 Subject: [PATCH 11/14] Add logging to test_bucketing.py --- tests/python/train/test_bucketing.py | 50 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/tests/python/train/test_bucketing.py b/tests/python/train/test_bucketing.py index 047870b6b79e..e4c834176a82 100644 --- a/tests/python/train/test_bucketing.py +++ b/tests/python/train/test_bucketing.py @@ -4,14 +4,19 @@ import random from random import randint + def test_bucket_module(): import logging head = '%(asctime)-15s %(message)s' logging.basicConfig(level=logging.DEBUG, format=head) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + logging.getLogger('').addHandler(console) class DummySentenceIter(mx.rnn.BucketSentenceIter): """Dummy sentence iterator to output sentences the same as input. """ + def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1, data_name='data', label_name='l2_label', dtype='float32', layout='NTC'): @@ -40,7 +45,7 @@ def reset(self): num_layers = 2 len_vocab = 100 buckets = [10, 20, 30, 40, 50, 60] - + invalid_label = 0 num_sentence = 2500 @@ -57,14 +62,14 @@ def reset(self): train_sent.append(train_sentence) val_sent.append(val_sentence) - data_train = DummySentenceIter(train_sent, batch_size, buckets=buckets, - invalid_label=invalid_label) - data_val = DummySentenceIter(val_sent, batch_size, buckets=buckets, - invalid_label=invalid_label) + data_train = DummySentenceIter(train_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) + data_val = DummySentenceIter(val_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) stack = mx.rnn.SequentialRNNCell() for i in range(num_layers): - stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i)) + stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % 
i)) def sym_gen(seq_len): data = mx.sym.Variable('data') @@ -77,7 +82,7 @@ def sym_gen(seq_len): pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden)) pred = mx.sym.FullyConnected(data=pred, num_hidden=1, name='pred') - pred = mx.sym.reshape(pred, shape= (batch_size, -1)) + pred = mx.sym.reshape(pred, shape=(batch_size, -1)) loss = mx.sym.LinearRegressionOutput(pred, label, name='l2_loss') return loss, ('data',), ('l2_label',) @@ -85,23 +90,26 @@ def sym_gen(seq_len): contexts = mx.cpu(0) model = mx.mod.BucketingModule( - sym_gen = sym_gen, - default_bucket_key = data_train.default_bucket_key, - context = contexts) + sym_gen=sym_gen, + default_bucket_key=data_train.default_bucket_key, + context=contexts) + logging.info('Begin fit...') model.fit( - train_data = data_train, - eval_data = data_val, - eval_metric = mx.metric.MSE(), - kvstore = 'device', - optimizer = 'sgd', - optimizer_params = { 'learning_rate': 0.01, - 'momentum': 0, - 'wd': 0.00001 }, - initializer = mx.init.Xavier(factor_type="in", magnitude=2.34), - num_epoch = num_epochs, - batch_end_callback = mx.callback.Speedometer(batch_size, 50)) + train_data=data_train, + eval_data=data_val, + eval_metric=mx.metric.MSE(), + kvstore='device', + optimizer='sgd', + optimizer_params={'learning_rate': 0.01, + 'momentum': 0, + 'wd': 0.00001}, + initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), + num_epoch=num_epochs, + batch_end_callback=mx.callback.Speedometer(batch_size, 50)) + logging.info('Finished fit...') assert model.score(data_val, mx.metric.MSE())[0][1] < 15, "High mean square error." 
+ if __name__ == "__main__": test_bucket_module() From 3a43edf4d0212038b883363245c308edd6bd9fe5 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 31 May 2017 17:32:49 -0700 Subject: [PATCH 12/14] Reduce model size (#4) --- tests/python/train/test_bucketing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/train/test_bucketing.py b/tests/python/train/test_bucketing.py index e4c834176a82..85ea107c5ca2 100644 --- a/tests/python/train/test_bucketing.py +++ b/tests/python/train/test_bucketing.py @@ -39,15 +39,15 @@ def reset(self): self.ndlabel.append(mx.nd.array(buck, dtype=self.dtype)) batch_size = 128 - num_epochs = 20 - num_hidden = 50 - num_embed = 50 + num_epochs = 5 + num_hidden = 25 + num_embed = 25 num_layers = 2 - len_vocab = 100 - buckets = [10, 20, 30, 40, 50, 60] + len_vocab = 50 + buckets = [10, 20, 30, 40] invalid_label = 0 - num_sentence = 2500 + num_sentence = 1000 train_sent = [] val_sent = [] @@ -108,7 +108,7 @@ def sym_gen(seq_len): num_epoch=num_epochs, batch_end_callback=mx.callback.Speedometer(batch_size, 50)) logging.info('Finished fit...') - assert model.score(data_val, mx.metric.MSE())[0][1] < 15, "High mean square error." + assert model.score(data_val, mx.metric.MSE())[0][1] < 350, "High mean square error." 
if __name__ == "__main__": From 1115d19c5c85deb7fc328b28fc4612d79073d742 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 1 Jun 2017 20:32:22 -0700 Subject: [PATCH 13/14] Add checks for shape/type inferences --- src/executor/graph_executor.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 64e2410bedb2..20135d33dee6 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -425,7 +425,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_dtypes.resize(idx.input_nodes().size(), -1); // Infer shapes and dtypes g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) + << "Shape inference failed in bind. Please provide" + " sufficient shapes to make inference for the symbol"; g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) + << "Type inference failed in bind. Please provide" + " sufficcient types to make inference for the symbol"; // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -738,7 +744,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) + << "Shape inference failed in simple_bind. Please provide" + " sufficient shapes to make inference for the symbol"; g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) + << "Type inference failed in simple_bind. Please provide" + " sufficcient types to make inference for the symbol"; // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. 
From d3c1d5f9859f82f5f313046c70f8414c903ed1a2 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 2 Jun 2017 13:29:58 -0700 Subject: [PATCH 14/14] Add printing error messages for shape/type inference failure --- python/mxnet/symbol.py | 68 ++++++++++++++------------ src/executor/graph_executor.cc | 87 +++++++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 49 deletions(-) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 0a26afec4731..d1f52b4b48f5 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1336,37 +1336,43 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, num_aux_states = ctypes.c_uint() aux_state_handles = ctypes.POINTER(NDArrayHandle)() - check_call(_LIB.MXExecutorSimpleBind(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_array(ctypes.c_char_p, provided_arg_shape_names), - c_array(mx_uint, provided_arg_shape_data), - c_array(mx_uint, provided_arg_shape_idx), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - mx_uint(len(shared_arg_name_list)), - c_array(ctypes.c_char_p, shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + try: + check_call(_LIB.MXExecutorSimpleBind(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + 
mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_array(ctypes.c_char_p, provided_arg_shape_names), + c_array(mx_uint, provided_arg_shape_data), + c_array(mx_uint, provided_arg_shape_idx), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + mx_uint(len(shared_arg_name_list)), + c_array(ctypes.c_char_p, shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + except MXNetError: + print("simple_bind error. Arguments:") + for k, v in kwargs.items(): + print(" %s: %s" % (k, v)) + raise RuntimeError('simple_bind failed') # update shared_buffer if shared_buffer is not None: diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 20135d33dee6..b41d1734d946 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -359,6 +359,53 @@ Graph AssignContext(Graph g, return g; } +void HandleInferShapeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_shape << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferShape pass cannot decide shapes for the following 
arguments " + "(0s in shapes mean unknown dimension size). Please consider " + "providing them as inputs:\n" + << oss.str(); +} + +void HandleInferTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_dtype = inferred_dtypes[eid]; + if (inferred_dtype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_dtype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments " + "(-1 means unknown dtype). Please consider providing them as inputs:\n" + << oss.str(); +} + /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -391,7 +438,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // create arg_shapes and arg_dtypes for shape and type inferences const auto& idx = g.indexed_graph(); - auto mutable_nodes = idx.mutable_input_nodes(); + const auto& mutable_nodes = idx.mutable_input_nodes(); size_t arg_top = 0, aux_top = 0; data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; @@ -422,16 +469,18 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // expand arg_shapes and arg_dtypes to contain backward inputs arg_shapes.resize(idx.input_nodes().size(), TShape()); - arg_dtypes.resize(idx.input_nodes().size(), -1); - // Infer shapes and dtypes g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); - CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) - << "Shape inference failed in bind. 
Please provide" - " sufficient shapes to make inference for the symbol"; + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + arg_dtypes.resize(idx.input_nodes().size(), -1); g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); - CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) - << "Type inference failed in bind. Please provide" - " sufficcient types to make inference for the symbol"; + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -459,8 +508,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, // populate grad_store_ data_entry_.resize(idx.num_node_entries()); size_t arg_top = 0, aux_top = 0; - auto mutable_nodes = idx.mutable_input_nodes(); - // TODO(junwu): populate in_arg_map, arg_grad_map, and aux_state_map + const auto& mutable_nodes = idx.mutable_input_nodes(); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const uint32_t eid = idx.entry_id(nid, 0); @@ -545,7 +593,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, // initialize in_args, arg_grads, and aux_states and populate grad_store_ data_entry_.resize(idx.num_node_entries()); size_t arg_top = 0, aux_top = 0; - auto mutable_nodes = idx.mutable_input_nodes(); + const auto& mutable_nodes = idx.mutable_input_nodes(); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const uint32_t eid = idx.entry_id(nid, 0); @@ -744,13 +792,16 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); - CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) - << "Shape inference failed in simple_bind. 
Please provide" - " sufficient shapes to make inference for the symbol"; + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); - CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) - << "Type inference failed in simple_bind. Please provide" - " sufficcient types to make inference for the symbol"; + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes.