Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Move more front-end work to backend
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 11, 2017
1 parent 99dd2e7 commit fa28df0
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 224 deletions.
20 changes: 6 additions & 14 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1086,14 +1086,10 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint in_arg_len,
const int* in_arg_dev_types,
const int* in_arg_dev_ids,
const mx_uint* grad_req_types,
const mx_uint aux_state_len,
const int* aux_state_dev_types,
const int* aux_state_dev_ids,
const mx_uint num_provided_args,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
Expand All @@ -1105,14 +1101,10 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
mx_uint* num_shared_data_arrays,
const char*** shared_data_array_name_list,
NDArrayHandle** shared_data_array_handle_list,
const mx_uint num_shared_exec_in_args,
NDArrayHandle* shared_exec_in_arg_handles,
const mx_uint num_shared_exec_arg_grads,
NDArrayHandle* shared_exec_arg_grad_handles,
const mx_uint num_shared_exec_aux_states,
NDArrayHandle* shared_exec_aux_state_handles,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
Expand Down
21 changes: 18 additions & 3 deletions include/mxnet/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ class Executor {
* \return array of outputs in the executor.
*/
virtual const std::vector<NDArray> &outputs() const = 0;
/*!
* \brief get input argument map, key is arg name, value is arg's NDArray.
* \return input argument map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
/*!
* \brief get input argument graident map, key is arg name, value is gradient's NDArray.
* \return input argument gradient map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
/*!
* \brief get aux state map, key is arg name, value is aux state's NDArray.
* \return aux state map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
Expand Down Expand Up @@ -102,9 +117,9 @@ class Executor {
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
const std::vector<NDArray>& shared_exec_in_args,
const std::vector<NDArray>& shared_exec_arg_grads,
const std::vector<NDArray>& shared_exec_aux_states,
//const std::vector<NDArray>& shared_exec_in_args,
//const std::vector<NDArray>& shared_exec_arg_grads,
//const std::vector<NDArray>& shared_exec_aux_states,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
Expand Down
195 changes: 96 additions & 99 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,78 +1166,99 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
executor : mxnet.Executor
The generated executor
"""
listed_arguments = self.list_arguments() # read-only args
listed_aux_states = self.list_auxiliary_states() # aux states

attr_dict = None
if type_dict is None:
attr_dict = self.attr_dict()
type_dict = {k: mx_real_t for k in listed_arguments
if k not in attr_dict or '__dtype__' not in attr_dict[k]}

provided_arg_type_names = [] # provided type argument names
provided_arg_type_data = [] # provided types
for k, v in type_dict.items():
v = _numpy.dtype(v).type
if v in _DTYPE_NP_TO_MX:
provided_arg_type_names.append(k)
provided_arg_type_data.append(_DTYPE_NP_TO_MX[v])
# listed_arguments = self.list_arguments() # read-only args
# listed_aux_states = self.list_auxiliary_states() # aux states

# attr_dict = None
# if type_dict is None:
# attr_dict = self.attr_dict()
# type_dict = {k: mx_real_t for k in listed_arguments
# if k not in attr_dict or '__dtype__' not in attr_dict[k]}

num_provided_arg_types = 0
provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types
if type_dict is not None:
provided_arg_type_names = []
provided_arg_type_data = []
for k, v in type_dict.items():
v = _numpy.dtype(v).type
if v in _DTYPE_NP_TO_MX:
provided_arg_type_names.append(c_str(k))
provided_arg_type_data.append(ctypes.c_int(_DTYPE_NP_TO_MX[v]))
num_provided_arg_types = mx_uint(len(provided_arg_type_names))
provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names)
provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data)

provided_arg_shape_data = [] # shape data
# argument shape index in sdata, e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
provided_arg_shape_idx = [0]
provided_arg_shape_names = [] # provided argument names
for k, v in kwargs.items():
if k not in listed_arguments and k not in listed_aux_states:
raise ValueError('arg name %s is not valid', k)
# if k not in listed_arguments and k not in listed_aux_states:
# raise ValueError('arg name %s is not valid', k)
if isinstance(v, tuple):
provided_arg_shape_names.append(c_str(k))
provided_arg_shape_data.extend(v)
provided_arg_shape_idx.append(len(provided_arg_shape_data))

req_map = {'null': 0, 'write': 1, 'add': 3}
if isinstance(grad_req, string_types):
if grad_req not in req_map:
raise ValueError('grad_req=%s is not in %s' % grad_req, str(req_map))
grad_req_types = [mx_uint(req_map[grad_req])] * len(listed_arguments)
elif isinstance(grad_req, list):
grad_req_types = [mx_uint(req_map[item]) for item in grad_req]
elif isinstance(grad_req, dict):
grad_req_types = []
for name in listed_arguments:
if name in grad_req:
grad_req_types.append(mx_uint(req_map[grad_req[name]]))
else:
grad_req_types.append(mx_uint(0))

if group2ctx is not None:
if attr_dict is None:
attr_dict = self.attr_dict()
arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
if name in attr_dict and '__ctx_group__' in attr_dict[name]
else ctx for name in listed_arguments]
aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
if name in attr_dict and '__ctx_group__' in attr_dict[name]
else ctx for name in listed_aux_states]
else:
arg_ctx = [ctx] * len(listed_arguments)
aux_ctx = [ctx] * len(listed_aux_states)

ctx_map_keys = []
ctx_map_dev_types = []
ctx_map_dev_ids = []
provided_req_type_list_len = 0
provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)()
provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)()
if grad_req is not None:
if isinstance(grad_req, string_types):
# use provided_req_type_list_len = 0 to indicate this situation
provided_req_type_list_len = 0
provided_grad_req_types = [c_str(grad_req)]
elif isinstance(grad_req, list):
provided_grad_req_types = [c_str(item) for item in grad_req]
provided_req_type_list_len = len(provided_grad_req_types)
elif isinstance(grad_req, dict):
provided_grad_req_names = []
provided_grad_req_types = []
for k, v in grad_req.items():
provided_grad_req_names.append(c_str(k))
provided_grad_req_types.append(c_str(v))
provided_grad_req_names = c_array(ctypes.c_char_p, provided_grad_req_names)
provided_req_type_list_len = len(provided_grad_req_types)
provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types)

# if group2ctx is not None:
# if attr_dict is None:
# attr_dict = self.attr_dict()
# arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
# if name in attr_dict and '__ctx_group__' in attr_dict[name]
# else ctx for name in listed_arguments]
# aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
# if name in attr_dict and '__ctx_group__' in attr_dict[name]
# else ctx for name in listed_aux_states]
# else:
# arg_ctx = [ctx] * len(listed_arguments)
# aux_ctx = [ctx] * len(listed_aux_states)

num_ctx_map_keys = mx_uint(0)
ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)()
ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)()
ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)()
if group2ctx is not None:
ctx_map_keys = []
ctx_map_dev_types = []
ctx_map_dev_ids = []
for key, val in group2ctx.items():
ctx_map_keys.append(c_str(key))
ctx_map_dev_types.append(ctypes.c_int(val.device_typeid))
ctx_map_dev_ids.append(ctypes.c_int(val.device_id))
num_ctx_map_keys = mx_uint(len(ctx_map_keys))
ctx_map_keys = c_array(ctypes.c_char_p, ctx_map_keys)
ctx_map_dev_types = c_array(ctypes.c_int, ctx_map_dev_types)
ctx_map_dev_ids = c_array(ctypes.c_int, ctx_map_dev_ids)

# prepare param names
if param_names is None:
param_names = []
else:
param_name_list = []
if param_names is not None:
if not isinstance(param_names, list):
raise ValueError('param_names in simple_bind must be a list or None')
param_name_list = [c_str(name) for name in param_names]

# prepare shared_data_arrays
if shared_data_arrays is None:
Expand All @@ -1250,27 +1271,11 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_data_array_names = []
shared_data_array_handles = []
for k, v in shared_data_arrays.items():
shared_data_array_names.append(k)
shared_data_array_names.append(c_str(k))
shared_data_array_handles.append(v.handle)
shared_data_array_names = ctypes.POINTER(ctypes.c_char_p)(shared_data_array_names)
shared_data_array_names = c_array(ctypes.c_char_p, shared_data_array_names)
num_shared_data_arrays = mx_uint(len(shared_data_array_handles))

# prepare shared_exec_in_args
if shared_exec is None:
num_shared_exec_in_args = 0
shared_exec_in_arg_handles = ctypes.POINTER(NDArrayHandle)()
num_shared_exec_arg_grads = 0
shared_exec_arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
num_shared_exec_aux_states = 0
shared_exec_aux_state_handles = ctypes.POINTER(NDArrayHandle)()
else:
shared_exec_in_arg_handles = [nd.handle for nd in shared_exec.arg_arrays]
num_shared_exec_in_args = len(shared_exec_in_arg_handles)
shared_exec_arg_grad_handles = [nd.handle if nd is not None
else None for nd in shared_exec.grad_arrays]
num_shared_exec_arg_grads = len(shared_exec_arg_grad_handles)
shared_exec_aux_state_handles = [nd.handle for nd in shared_exec.aux_arrays]
num_shared_exec_aux_states = len(shared_exec_aux_state_handles)
shared_data_array_handles = c_array(NDArrayHandle, shared_data_array_handles)

# prepare shared_exec_handle
shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()
Expand All @@ -1279,46 +1284,38 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
exe_handle = ExecutorHandle()

# prepare current executor's in_args, arg_grads, and aux_states
num_in_args = ctypes.c_uint()
in_arg_handles = ctypes.POINTER(NDArrayHandle)()
arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
num_aux_states = ctypes.c_uint()
aux_state_handles = ctypes.POINTER(NDArrayHandle)()

check_call(_LIB.MXExecutorSimpleBind(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
mx_uint(len(ctx_map_keys)),
c_array(ctypes.c_char_p, ctx_map_keys),
c_array(ctypes.c_int, ctx_map_dev_types),
c_array(ctypes.c_int, ctx_map_dev_ids),
mx_uint(len(listed_arguments)),
c_array(ctypes.c_int,
[in_arg_ctx.device_typeid for in_arg_ctx in arg_ctx]),
c_array(ctypes.c_int, [in_arg_ctx.device_id for in_arg_ctx in arg_ctx]),
c_array(mx_uint, grad_req_types),
mx_uint(len(listed_aux_states)),
c_array(ctypes.c_int,
[aux_state_ctx.device_typeid for aux_state_ctx in aux_ctx]),
c_array(ctypes.c_int,
[aux_state_ctx.device_id for aux_state_ctx in aux_ctx]),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_array(ctypes.c_char_p, provided_arg_shape_names),
c_array(mx_uint, provided_arg_shape_data),
c_array(mx_uint, provided_arg_shape_idx),
len(provided_arg_type_names),
c_array(ctypes.c_char_p, provided_arg_type_names),
c_array(ctypes.c_int, provided_arg_type_data),
mx_uint(len(param_names)),
c_array(ctypes.c_char_p, param_names),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
mx_uint(len(param_name_list)),
c_array(ctypes.c_char_p, param_name_list),
ctypes.byref(num_shared_data_arrays),
ctypes.byref(shared_data_array_names),
ctypes.byref(shared_data_array_handles),
mx_uint(num_shared_exec_in_args),
shared_exec_in_arg_handles,
mx_uint(num_shared_exec_arg_grads),
shared_exec_arg_grad_handles,
mx_uint(num_shared_exec_aux_states),
shared_exec_aux_state_handles,
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
Expand All @@ -1333,11 +1330,11 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_data_arrays[k] = v

# create in_args, arg_grads, and aux_states for the current executor
arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(len(listed_arguments))]
arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)]
grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i]))
if arg_grad_handles[i] is not None
else None for i in range(len(listed_arguments))]
aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) for i in range(len(listed_aux_states))]
else None for i in range(num_in_args.value)]
aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) for i in range(num_aux_states.value)]

executor = Executor(exe_handle, self, ctx, grad_req, group2ctx)
executor.arg_arrays = arg_arrays
Expand Down
Loading

0 comments on commit fa28df0

Please sign in to comment.