From 375dbc4cd3b1108aa6df668b28954b559576186f Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 31 Aug 2020 16:41:24 +0000 Subject: [PATCH 1/2] Support input nodes with multiple data entries --- src/runtime/contrib/json/json_runtime.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 92830e663d25..9362a0e256fa 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -150,7 +150,7 @@ class JSONRuntimeBase : public ModuleNode { << "Found mismatch in the number of provided data entryies and required."; for (size_t i = 0; i < static_cast(args.size()); i++) { - auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0) + auto eid = i < input_var_idx_.size() ? input_var_idx_[i] : EntryID(outputs_[i - input_var_idx_.size()]); CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) << "Expect NDArray or DLTensor as inputs"; @@ -183,7 +183,10 @@ class JSONRuntimeBase : public ModuleNode { uint32_t nid = input_nodes_[i]; std::string name = nodes_[nid].name_; if (nodes_[nid].op_type_ == "input") { - input_var_idx_.push_back(nid); + CHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size()); + for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { + input_var_idx_.push_back(EntryID(nid, j)); + } } else { CHECK_EQ(nodes_[nid].op_type_, "const"); auto pos = std::find(std::begin(const_names_), std::end(const_names_), name); From fa418f6f875dfa8558507f39dd2d6eaaa0fc7f66 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 31 Aug 2020 23:40:49 +0000 Subject: [PATCH 2/2] Rename input_var_idx_ to input_var_eid_ --- src/runtime/contrib/json/json_runtime.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 9362a0e256fa..9eb7fcd2f689 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -146,12 +146,12 @@ class JSONRuntimeBase : public ModuleNode { * \param args The packed args. */ void SetInputOutputBuffers(const TVMArgs& args) { - CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size()) + CHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) << "Found mismatch in the number of provided data entryies and required."; for (size_t i = 0; i < static_cast(args.size()); i++) { - auto eid = i < input_var_idx_.size() ? input_var_idx_[i] - : EntryID(outputs_[i - input_var_idx_.size()]); + auto eid = i < input_var_eid_.size() ? input_var_eid_[i] + : EntryID(outputs_[i - input_var_eid_.size()]); CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) << "Expect NDArray or DLTensor as inputs"; @@ -185,7 +185,7 @@ class JSONRuntimeBase : public ModuleNode { if (nodes_[nid].op_type_ == "input") { CHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size()); for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { - input_var_idx_.push_back(EntryID(nid, j)); + input_var_eid_.push_back(EntryID(nid, j)); } } else { CHECK_EQ(nodes_[nid].op_type_, "const"); @@ -264,8 +264,8 @@ class JSONRuntimeBase : public ModuleNode { std::vector outputs_; /*! \brief Data of that entry. */ std::vector data_entry_; - /*! \brief Map the input name to node index. */ - std::vector input_var_idx_; + /*! \brief Map the input name to entry id. */ + std::vector input_var_eid_; /*! \brief input const node index. */ std::vector const_idx_; /*! \brief Indicate if the engine has been initialized. */