diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 3424a2ea97..b234117316 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -7,7 +7,6 @@
 #include "torch/csrc/jit/passes/lower_graph.h"
 #include "torch/csrc/jit/passes/lower_tuples.h"
 #include "torch/csrc/jit/passes/peephole.h"
-#include "torch/csrc/jit/passes/quantization.h"
 
 #include "core/util/prelude.h"
 #include "core/lowering/lowering.h"
@@ -50,8 +49,7 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
   return mod_;
 }
 
-std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
-                                                                             std::string method_name) {
+std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod, std::string method_name) {
   auto lowered_mod = LowerModule(mod);
   auto g = lowered_mod.get_method(method_name).graph();
   LOG_GRAPH(*g);
@@ -62,9 +60,14 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(con
   lowering::LowerGraph(g);
   //=[torch::jit::FoldConvBatchNorm2d(lowered_mod);
   LOG_GRAPH("LibTorch Lowering");
-  auto graph_and_parameters = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
+  auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
   // Is this necessary?
   lowering::LowerBlock(g->block());
+  std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> graph_and_parameters;
+  for (auto i : graph_and_ivalues.second) {
+    graph_and_parameters.second.push_back(i.toTensor());
+  }
+  graph_and_parameters.first = graph_and_ivalues.first;
 
   return graph_and_parameters;
 }
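
Context for the added loop: the call to torch::jit::LowerGraph now hands back the lifted module parameters as IValues (hence the rename to graph_and_ivalues), while the rest of the compiler still expects a std::vector of at::Tensor, so each IValue is unpacked with toTensor() before the pair is returned. A minimal standalone sketch of that conversion, assuming linking against libtorch; the helper name ToTensorParameters and the includes are illustrative, not part of the patch:

// Illustrative sketch (not part of the patch): the IValue -> Tensor unpacking
// the diff adds, factored into a hypothetical helper.
#include <memory>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>      // c10::IValue, at::Tensor
#include <torch/csrc/jit/ir/ir.h>  // torch::jit::Graph

std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> ToTensorParameters(
    const std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<c10::IValue>>& graph_and_ivalues) {
  std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> graph_and_parameters;
  graph_and_parameters.first = graph_and_ivalues.first;
  for (const auto& i : graph_and_ivalues.second) {
    // toTensor() throws if a lifted parameter is not actually a tensor,
    // which surfaces unsupported parameter types early.
    graph_and_parameters.second.push_back(i.toTensor());
  }
  return graph_and_parameters;
}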