From 4da22dbd8c04b0d87ec53072e7d014417e6bc63c Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 10 May 2017 15:36:45 -0700 Subject: [PATCH] Bug fix --- src/c_api/c_api_executor.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index c6075b1d144e..2bd4ca000736 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -246,7 +246,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, if (nullptr == provided_arg_dtypes) { // use attr_dict for (const auto& arg_name : in_arg_names) { const auto it = attr_dict.find(arg_name); - if (it == attr_dict.end() || !it->second.count("__ctx_group__")) { + if (it == attr_dict.end() || !it->second.count("__dtype__")) { arg_dtype_map[arg_name] = mshadow::kFloat32; } } @@ -410,6 +410,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, ret->ret_handles.push_back(new NDArray(nd)); } if (in_arg_vec.size() > 0) { + *num_in_args = in_arg_vec.size(); *in_args = &(ret->ret_handles[nd_idx]); nd_idx = ret->ret_handles.size(); } @@ -433,6 +434,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, ret->ret_handles.push_back(new NDArray(nd)); } if (aux_state_vec.size() > 0) { + *num_aux_states = aux_state_vec.size(); *aux_states = &(ret->ret_handles[nd_idx]); nd_idx = ret->ret_handles.size(); }