From 3cebe97d09f14777a60938800cd67a5a5e2dc802 Mon Sep 17 00:00:00 2001 From: inocsin Date: Wed, 31 Mar 2021 14:57:14 +0800 Subject: [PATCH] feat: support prim::Param for input type after refactor Signed-off-by: inocsin --- core/partitioning/shape_analysis.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index efb9fd6b6b..513124eb3d 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -64,7 +64,9 @@ void getSegmentsOutputByRunning( // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments for (auto& input : seg_block.raw_inputs()) { TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName()); - if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { + if (input->node()->kind() == torch::jit::prim::Param) { + jit_inputs_ivalues.push_back(ivalues_maps[input]); + } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor()); } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());