From 35bf4d045f917d640e32877dd528be99d8e1a33b Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 10 May 2017 21:56:18 -0700 Subject: [PATCH] Fix failed tests --- python/mxnet/symbol.py | 30 ++++++-------------------- src/c_api/c_api_executor.cc | 2 +- tests/python/unittest/test_executor.py | 2 +- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 26083dfdda71..70d275bf35a2 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1125,7 +1125,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, ---------- >>> x = mx.sym.Variable('x') >>> y = mx.sym.FullyConnected(x, num_hidden=4) - >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') >>> exe.forward() [] >>> exe.outputs[0].asnumpy() @@ -1166,15 +1166,6 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, executor : mxnet.Executor The generated executor """ - # listed_arguments = self.list_arguments() # read-only args - # listed_aux_states = self.list_auxiliary_states() # aux states - - # attr_dict = None - # if type_dict is None: - # attr_dict = self.attr_dict() - # type_dict = {k: mx_real_t for k in listed_arguments - # if k not in attr_dict or '__dtype__' not in attr_dict[k]} - num_provided_arg_types = 0 provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types @@ -1211,9 +1202,13 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, provided_req_type_list_len = 0 provided_grad_req_types = [c_str(grad_req)] elif isinstance(grad_req, list): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty list') provided_grad_req_types = [c_str(item) for item in grad_req] provided_req_type_list_len = len(provided_grad_req_types) elif isinstance(grad_req, dict): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty dict') provided_grad_req_names = [] provided_grad_req_types = [] for k, v in grad_req.items(): @@ -1223,19 +1218,6 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, provided_req_type_list_len = len(provided_grad_req_types) provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types) - # if group2ctx is not None: - # if attr_dict is None: - # attr_dict = self.attr_dict() - # arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) - # if name in attr_dict and '__ctx_group__' in attr_dict[name] - # else ctx for name in listed_arguments] - # aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) - # if name in attr_dict and '__ctx_group__' in attr_dict[name] - # else ctx for name in listed_aux_states] - # else: - # arg_ctx = [ctx] * len(listed_arguments) - # aux_ctx = [ctx] * len(listed_aux_states) - num_ctx_map_keys = mx_uint(0) ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)() ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)() @@ -1359,7 +1341,7 @@ def simple_bind_v1(self, ctx, ---------- >>> x = mx.sym.Variable('x') >>> y = mx.sym.FullyConnected(x, num_hidden=4) - >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') >>> exe.forward() [] >>> exe.outputs[0].asnumpy() diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 2bd4ca000736..fe71bf5db695 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -331,7 +331,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, } // initialize arg_grad_ctx_vec and grad_req_type_vec - std::vector arg_grad_ctx_vec(in_arg_names.size()); + std::vector arg_grad_ctx_vec(in_arg_names.size(), ctx); std::vector grad_req_type_vec(in_arg_names.size(), kNullOp); if ("none" != grad_req_type) { for (size_t i = 0; i < in_arg_names.size(); ++i) { diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index b190b2898843..c1cc013b81c0 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -121,7 +121,7 @@ def test_reshape(): x = mx.sym.Variable('x') y = mx.sym.FullyConnected(x, num_hidden=4) - exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') exe.arg_arrays[0][:] = 1 exe.arg_arrays[1][:] = mx.nd.ones((4,4)) exe.arg_arrays[2][:] = 0