From fef4f3b75614665dfd362f6c32d28b01549b6f8e Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 1 Jun 2017 20:32:22 -0700 Subject: [PATCH] Add checks for shape/type inferences --- src/executor/graph_executor.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 1623bfdc576b..3bc812f2bf3c 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -424,7 +424,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_dtypes.resize(idx.input_nodes().size(), -1); // Infer shapes and dtypes g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) + << "Shape inference failed in bind. Please provide" + " sufficient shapes to make inference for the symbol"; g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) + << "Type inference failed in bind. Please provide" + " sufficcient types to make inference for the symbol"; // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -737,7 +743,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) + << "Shape inference failed in simple_bind. Please provide" + " sufficient shapes to make inference for the symbol"; g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) + << "Type inference failed in simple_bind. Please provide" + " sufficcient types to make inference for the symbol"; // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes.