Skip to content

Commit

Permalink
[BYOC][JSON] Support input nodes with multiple entries (#6368)
Browse files Browse the repository at this point in the history
* Support input nodes with multiple data entries

* Rename input_var_idx_ to input_var_eid_
  • Loading branch information
Trevor Morris authored Sep 1, 2020
1 parent f56bb71 commit b8f37ee
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ class JSONRuntimeBase : public ModuleNode {
* \param args The packed args.
*/
void SetInputOutputBuffers(const TVMArgs& args) {
CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size())
CHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
<< "Found mismatch in the number of provided data entryies and required.";

for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0)
: EntryID(outputs_[i - input_var_idx_.size()]);
auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
: EntryID(outputs_[i - input_var_eid_.size()]);
CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
<< "Expect NDArray or DLTensor as inputs";

Expand Down Expand Up @@ -183,7 +183,10 @@ class JSONRuntimeBase : public ModuleNode {
uint32_t nid = input_nodes_[i];
std::string name = nodes_[nid].name_;
if (nodes_[nid].op_type_ == "input") {
input_var_idx_.push_back(nid);
CHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size());
for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
input_var_eid_.push_back(EntryID(nid, j));
}
} else {
CHECK_EQ(nodes_[nid].op_type_, "const");
auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
Expand Down Expand Up @@ -261,8 +264,8 @@ class JSONRuntimeBase : public ModuleNode {
std::vector<JSONGraphNodeEntry> outputs_;
/*! \brief Data of that entry. */
std::vector<const DLTensor*> data_entry_;
/*! \brief Map the input name to node index. */
std::vector<uint32_t> input_var_idx_;
/*! \brief Map the input name to entry id. */
std::vector<uint32_t> input_var_eid_;
/*! \brief input const node index. */
std::vector<uint32_t> const_idx_;
/*! \brief Indicate if the engine has been initialized. */
Expand Down

0 comments on commit b8f37ee

Please sign in to comment.