Skip to content

Commit

Permalink
Merge branch 'bowa_fallback' of https://github.com/NVIDIA/TRTorch int…
Browse files Browse the repository at this point in the history
…o bowa_fallback
  • Loading branch information
bowang007 committed Mar 30, 2021
2 parents 459a9b9 + 77b4dc7 commit cfc68ce
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ void registerSegmentInOutIValues(
// set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
for (auto& input : seg_block.raw_inputs()) {
TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
if (input->node()->kind() == torch::jit::prim::Param) {
jit_inputs_ivalues.push_back(ivalues_maps[input]);
} else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
Expand Down

0 comments on commit cfc68ce

Please sign in to comment.