From ec2bbf237fba16c89c9074be664300e09f32b306 Mon Sep 17 00:00:00 2001 From: inocsin Date: Tue, 30 Mar 2021 17:20:18 +0800 Subject: [PATCH] feat: support prim::Param for fallback inputs Signed-off-by: inocsin --- core/partitioning/partitioning.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 7aa8150177..96e3dfe686 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -106,7 +106,9 @@ void registerSegmentInOutIValues( // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments for (auto& input : seg_block.raw_inputs()) { TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName()); - if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { + if (input->node()->kind() == torch::jit::prim::Param) { + jit_inputs_ivalues.push_back(ivalues_maps[input]); + } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor()); } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());