From f542fe026c7247137c6daefc3f780e20ecbd26e4 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 2 Jun 2017 13:29:58 -0700 Subject: [PATCH] Add printing error messages for shape/type inference failure --- python/mxnet/symbol.py | 68 ++++++++++++++------------ src/executor/graph_executor.cc | 87 +++++++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 49 deletions(-) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 0a26afec4731..d1f52b4b48f5 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1336,37 +1336,43 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, num_aux_states = ctypes.c_uint() aux_state_handles = ctypes.POINTER(NDArrayHandle)() - check_call(_LIB.MXExecutorSimpleBind(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_array(ctypes.c_char_p, provided_arg_shape_names), - c_array(mx_uint, provided_arg_shape_data), - c_array(mx_uint, provided_arg_shape_idx), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - mx_uint(len(shared_arg_name_list)), - c_array(ctypes.c_char_p, shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + try: + check_call(_LIB.MXExecutorSimpleBind(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_array(ctypes.c_char_p, provided_arg_shape_names), + c_array(mx_uint, provided_arg_shape_data), + c_array(mx_uint, provided_arg_shape_idx), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + mx_uint(len(shared_arg_name_list)), + c_array(ctypes.c_char_p, shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + except MXNetError: + print("simple_bind error. Arguments:") + for k, v in kwargs.items(): + print(" %s: %s" % (k, v)) + raise RuntimeError('simple_bind failed') # update shared_buffer if shared_buffer is not None: diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 3bc812f2bf3c..de5b1ed06016 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -358,6 +358,53 @@ Graph AssignContext(Graph g, return g; } +void HandleInferShapeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_shape << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments " + "(0s in shapes mean unknown dimension size). Please consider " + "providing them as inputs:\n" + << oss.str(); +} + +void HandleInferTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_dtype = inferred_dtypes[eid]; + if (inferred_dtype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_dtype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments " + "(-1 means unknown dtype). Please consider providing them as inputs:\n" + << oss.str(); +} + /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -390,7 +437,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // create arg_shapes and arg_dtypes for shape and type inferences const auto& idx = g.indexed_graph(); - auto mutable_nodes = idx.mutable_input_nodes(); + const auto& mutable_nodes = idx.mutable_input_nodes(); size_t arg_top = 0, aux_top = 0; data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; @@ -421,16 +468,18 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // expand arg_shapes and arg_dtypes to contain backward inputs arg_shapes.resize(idx.input_nodes().size(), TShape()); - arg_dtypes.resize(idx.input_nodes().size(), -1); - // Infer shapes and dtypes g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); - CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) - << "Shape inference failed in bind. Please provide" - " sufficient shapes to make inference for the symbol"; + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + arg_dtypes.resize(idx.input_nodes().size(), -1); g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); - CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) - << "Type inference failed in bind. Please provide" - " sufficcient types to make inference for the symbol"; + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -458,8 +507,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, // populate grad_store_ data_entry_.resize(idx.num_node_entries()); size_t arg_top = 0, aux_top = 0; - auto mutable_nodes = idx.mutable_input_nodes(); - // TODO(junwu): populate in_arg_map, arg_grad_map, and aux_state_map + const auto& mutable_nodes = idx.mutable_input_nodes(); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const uint32_t eid = idx.entry_id(nid, 0); @@ -544,7 +592,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, // initialize in_args, arg_grads, and aux_states and populate grad_store_ data_entry_.resize(idx.num_node_entries()); size_t arg_top = 0, aux_top = 0; - auto mutable_nodes = idx.mutable_input_nodes(); + const auto& mutable_nodes = idx.mutable_input_nodes(); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const uint32_t eid = idx.entry_id(nid, 0); @@ -743,13 +791,16 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); - CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) - << "Shape inference failed in simple_bind. Please provide" - " sufficient shapes to make inference for the symbol"; + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); - CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) - << "Type inference failed in simple_bind. Please provide" - " sufficcient types to make inference for the symbol"; + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes.