From 4acc3fddc8b895553422f82129c66fb7bd147064 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 30 Apr 2020 21:30:52 -0700 Subject: [PATCH] feat(//core/lowering): New freeze model pass and new exception elimination pass Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/converters/impl/linear.cpp | 7 ++ core/lowering/BUILD | 3 +- core/lowering/lowering.cpp | 35 +++++--- core/lowering/lowering.h | 2 +- core/lowering/{irfusers => passes}/BUILD | 8 +- .../lowering/passes/exception_elimination.cpp | 86 +++++++++++++++++++ .../expand_log_softmax.cpp | 10 +-- .../fuse_flatten_linear.cpp | 40 ++++++++- .../{irfusers/irfusers.h => passes/passes.h} | 3 +- .../{irfusers => passes}/remove_dropout.cpp | 6 +- .../unpack_batch_norm.cpp | 4 +- core/util/logging/TRTorchLogger.cpp | 2 +- core/util/macros.h | 6 +- tests/modules/hub.py | 6 +- tests/modules/test_compiled_modules.cpp | 6 +- tests/modules/test_modules_as_engines.cpp | 6 +- .../test_multiple_registered_engines.cpp | 4 +- 17 files changed, 188 insertions(+), 46 deletions(-) rename core/lowering/{irfusers => passes}/BUILD (74%) create mode 100644 core/lowering/passes/exception_elimination.cpp rename core/lowering/{irfusers => passes}/expand_log_softmax.cpp (95%) rename core/lowering/{irfusers => passes}/fuse_flatten_linear.cpp (52%) rename core/lowering/{irfusers/irfusers.h => passes/passes.h} (82%) rename core/lowering/{irfusers => passes}/remove_dropout.cpp (93%) rename core/lowering/{irfusers => passes}/unpack_batch_norm.cpp (98%) diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp index 0b5b4e957a..faca5c3edd 100644 --- a/core/conversion/converters/impl/linear.cpp +++ b/core/conversion/converters/impl/linear.cpp @@ -9,6 +9,13 @@ namespace impl { namespace { auto linear_registrations = RegisterNodeConversionPatterns() + // .pattern({ + // "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)", + // [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> { + // auto in = args[0].ITensor(); + + // } + // }) .pattern({ "aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { diff --git a/core/lowering/BUILD b/core/lowering/BUILD index 9a0c53c13b..4cb1978251 100644 --- a/core/lowering/BUILD +++ b/core/lowering/BUILD @@ -11,7 +11,8 @@ cc_library( ], deps = [ "@libtorch//:libtorch", - "//core/lowering/irfusers" + "//core/lowering/passes", + "//core/util:prelude" ] ) diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index d18a9d612d..53ff6d36d8 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -1,10 +1,13 @@ #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/passes/fuse_linear.h" +#include "torch/csrc/jit/passes/freeze_module.h" #include "torch/csrc/jit/passes/lower_graph.h" #include "torch/csrc/jit/passes/quantization.h" +#include "torch/csrc/jit/passes/guard_elimination.h" +#include "core/util/prelude.h" #include "core/lowering/lowering.h" -#include "core/lowering/irfusers/irfusers.h" +#include "core/lowering/passes/passes.h" namespace trtorch { namespace core { @@ -17,30 +20,36 @@ void LowerBlock(torch::jit::Block* b) { } void LowerGraph(std::shared_ptr& g) { + torch::jit::EliminateRedundantGuards(g); + passes::EliminateExceptionOrPassPattern(g); torch::jit::FuseLinear(g); - irfusers::RemoveDropout(g); - irfusers::FuseFlattenLinear(g); - irfusers::ExpandLogSoftmax(g); + passes::RemoveDropout(g); + passes::FuseFlattenLinear(g); + passes::ExpandLogSoftmax(g); + //passes::RemoveDimExeception(g); //irfusers::UnpackBatchNorm(g); - //torch::jit::EliminateDeadCode(g); + torch::jit::EliminateDeadCode(g); + LOG_GRAPH(*g); } -void LowerModule(const torch::jit::script::Module& mod) { - torch::jit::FoldConvBatchNorm2d(mod); +torch::jit::Module LowerModule(const torch::jit::script::Module& mod) { + auto mod_ = torch::jit::freeze_module(mod); + return mod_; } std::pair, std::vector> Lower(const torch::jit::script::Module& mod, std::string method_name) { - LowerModule(mod); - auto g = mod.get_method(method_name).graph(); - // Go through PyTorch Lowering to simplify graph and extract weight parameters - auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); - - g = graph_and_parameters.first; + auto lowered_mod = LowerModule(mod); + auto g = lowered_mod.get_method(method_name).graph(); + LOG_GRAPH(*g); // Go through TRTorch Lowering to reformat graph to be conversion friendly // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) + LOG_GRAPH("TRTorch Graph Lowering"); lowering::LowerGraph(g); + //=[torch::jit::FoldConvBatchNorm2d(lowered_mod); + LOG_GRAPH("LibTorch Lowering"); + auto graph_and_parameters = torch::jit::LowerGraph(*g, lowered_mod._ivalue()); // Is this necessary? lowering::LowerBlock(g->block()); return graph_and_parameters; diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h index 8ee2cdde53..79f07cb5ec 100644 --- a/core/lowering/lowering.h +++ b/core/lowering/lowering.h @@ -8,7 +8,7 @@ namespace lowering { void LowerBlock(torch::jit::Block* b); void LowerGraph(std::shared_ptr& g); -void LowerModule(const torch::jit::script::Module& mod); +torch::jit::Module LowerModule(const torch::jit::script::Module& mod); std::pair, std::vector> Lower(const torch::jit::script::Module& mod, std::string method_name); diff --git a/core/lowering/irfusers/BUILD b/core/lowering/passes/BUILD similarity index 74% rename from core/lowering/irfusers/BUILD rename to core/lowering/passes/BUILD index 71899dfee6..0de638898c 100644 --- a/core/lowering/irfusers/BUILD +++ b/core/lowering/passes/BUILD @@ -1,17 +1,19 @@ package(default_visibility = ["//visibility:public"]) cc_library( - name = "irfusers", + name = "passes", hdrs = [ - "irfusers.h", + "passes.h", ], srcs = [ "fuse_flatten_linear.cpp", "expand_log_softmax.cpp", "remove_dropout.cpp", - "unpack_batch_norm.cpp" + "unpack_batch_norm.cpp", + "exception_elimination.cpp" ], deps = [ + "//core/util:prelude", "@libtorch//:libtorch", ] ) diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp new file mode 100644 index 0000000000..54f36a833a --- /dev/null +++ b/core/lowering/passes/exception_elimination.cpp @@ -0,0 +1,86 @@ +#include "torch/csrc/jit/passes/guard_elimination.h" +#include "torch/csrc/jit/ir/alias_analysis.h" +#include "torch/csrc/jit/jit_log.h" +#include "torch/csrc/jit/passes/constant_propagation.h" +#include "torch/csrc/jit/passes/peephole.h" +#include "torch/csrc/jit/runtime/graph_executor.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" + +#include "core/util/prelude.h" + +#include + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { +namespace { +using namespace torch::jit; +struct ExceptionOrPassPatternElimination { + ExceptionOrPassPatternElimination(std::shared_ptr graph) + : graph_(std::move(graph)) {} + + void run() { + LOG_GRAPH("Pre exeception or pass elimination: " << *graph_); + findExceptionOrPassNodes(graph_->block()); + torch::jit::EliminateDeadCode(graph_); + LOG_GRAPH("Post exeception or pass elimination: " << *graph_); + } + +private: + bool isExceptionOrPassNode(Node* n) { + /// Check if this Node hosts a pattern like so: + /// = prim::If(%5958) + /// block0(): + /// = prim::RaiseException(%45) + /// -> () + /// block1(): + /// -> () + if (n->blocks().size() != 2) { + return false; + } + auto arm1 = n->blocks()[0]; + auto arm2 = n->blocks()[1]; + if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) { + // Make sure that the node doesn't actually produce any Value that are used by other nodes + return false; + } + + auto arm1_start = arm1->nodes().begin(); + + if ((*arm1_start)->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) { + // Make sure that block0 is solely just the exception and the return + return false; + } + + if ((*(arm2->nodes().begin()))->kind() != prim::Return) { + // Make sure that block1 is solely the return + return false; + } + + return true; + } + + void findExceptionOrPassNodes(Block* b) { + for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { + auto n = *it; + if (n->kind() == prim::If && isExceptionOrPassNode(n)) { + LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)"); + it.destroyCurrent(); + } + } + } + + std::shared_ptr graph_; +}; +} // namespace + +void EliminateExceptionOrPassPattern(std::shared_ptr graph) { + ExceptionOrPassPatternElimination eppe(std::move(graph)); + eppe.run(); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace trtorch diff --git a/core/lowering/irfusers/expand_log_softmax.cpp b/core/lowering/passes/expand_log_softmax.cpp similarity index 95% rename from core/lowering/irfusers/expand_log_softmax.cpp rename to core/lowering/passes/expand_log_softmax.cpp index c9c3cc27ce..37bbcdd321 100644 --- a/core/lowering/irfusers/expand_log_softmax.cpp +++ b/core/lowering/passes/expand_log_softmax.cpp @@ -4,14 +4,14 @@ namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { void ExpandLogSoftmax(std::shared_ptr& graph) { // Its easier for TensorRT if we seperate softmax and log // There might need to be a reshape inserted see: // https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593 // Should the reshapes be added here or in the converter? - + // TODO: In the future this should be removed for a deicated log_softmax converter (more efficent) // But its easier to stand up a working system if the number of op converters is lower std::string logsoftmax_pattern = R"IR( @@ -33,19 +33,19 @@ void ExpandLogSoftmax(std::shared_ptr& graph) { %dtype : int? = prim::Constant() %softmax = aten::softmax(%input, %dim, %dtype) %log_softmax = aten::log(%softmax) - return (%log_softmax))IR"; + return (%log_softmax))IR"; torch::jit::SubgraphRewriter logsoftmax_to_softmax_log; logsoftmax_to_softmax_log.RegisterRewritePattern(logsoftmax_pattern, softmax_log_pattern); logsoftmax_to_softmax_log.runOnGraph(graph); - + torch::jit::SubgraphRewriter logsoftmax_none_to_softmax_log_none; logsoftmax_none_to_softmax_log_none.RegisterRewritePattern( logsoftmax_none_pattern, softmax_log_none_pattern); logsoftmax_none_to_softmax_log_none.runOnGraph(graph); } -} // namespace irfusers +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/irfusers/fuse_flatten_linear.cpp b/core/lowering/passes/fuse_flatten_linear.cpp similarity index 52% rename from core/lowering/irfusers/fuse_flatten_linear.cpp rename to core/lowering/passes/fuse_flatten_linear.cpp index 5b8c3899ec..6dc8ebf68e 100644 --- a/core/lowering/irfusers/fuse_flatten_linear.cpp +++ b/core/lowering/passes/fuse_flatten_linear.cpp @@ -4,7 +4,7 @@ namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { void FuseFlattenLinear(std::shared_ptr& graph) { //TensorRT implicitly adds a flatten layer infront of FC layers if necessary @@ -33,13 +33,47 @@ void FuseFlattenLinear(std::shared_ptr& graph) { torch::jit::SubgraphRewriter flatten_linear_to_linear; flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear); flatten_linear_to_linear.runOnGraph(graph); - + + torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear; + flatten_linear_bias_none_to_linear.RegisterRewritePattern( + flatten_linear_bias_none_pattern, fused_linear_bias_none); + flatten_linear_bias_none_to_linear.runOnGraph(graph); +} + +void FuseFlattenAddMM(std::shared_ptr& graph) { + //TensorRT implicitly adds a flatten layer infront of FC layers if necessary + std::string flatten_linear_pattern = R"IR( + graph(%input, %6, %7, %weight, %bias): + %flat = aten::flatten(%input, %6, %7) + %res = aten::linear(%flat, %weight, %bias) + return (%res))IR"; + std::string flatten_linear_bias_none_pattern = R"IR( + graph(%input, %6, %7, %weight): + %flat = aten::flatten(%input, %6, %7) + %bias: Tensor? = prim::Constant() + %res = aten::linear(%flat, %weight, %bias) + return (%res))IR"; + std::string fused_linear = R"IR( + graph(%input, %6, %7, %weight, %bias): + %res = aten::linear(%input, %weight, %bias) + return (%res))IR"; + + std::string fused_linear_bias_none = R"IR( + graph(%input, %6, %7, %weight): + %bias: Tensor? = prim::Constant() + %res = aten::linear(%input, %weight, %bias) + return (%res))IR"; + + torch::jit::SubgraphRewriter flatten_linear_to_linear; + flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear); + flatten_linear_to_linear.runOnGraph(graph); + torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear; flatten_linear_bias_none_to_linear.RegisterRewritePattern( flatten_linear_bias_none_pattern, fused_linear_bias_none); flatten_linear_bias_none_to_linear.runOnGraph(graph); } -} // namespace irfusers +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/irfusers/irfusers.h b/core/lowering/passes/passes.h similarity index 82% rename from core/lowering/irfusers/irfusers.h rename to core/lowering/passes/passes.h index c48f5d112d..1ef02ac065 100644 --- a/core/lowering/irfusers/irfusers.h +++ b/core/lowering/passes/passes.h @@ -5,12 +5,13 @@ namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { void FuseFlattenLinear(std::shared_ptr& graph); void ExpandLogSoftmax(std::shared_ptr& graph); void RemoveDropout(std::shared_ptr& graph); void UnpackBatchNorm(std::shared_ptr& graph); +void EliminateExceptionOrPassPattern(std::shared_ptr graph); } // namespace irfusers } // namespace lowering diff --git a/core/lowering/irfusers/remove_dropout.cpp b/core/lowering/passes/remove_dropout.cpp similarity index 93% rename from core/lowering/irfusers/remove_dropout.cpp rename to core/lowering/passes/remove_dropout.cpp index fcd1d07aa0..b28fe011a8 100644 --- a/core/lowering/irfusers/remove_dropout.cpp +++ b/core/lowering/passes/remove_dropout.cpp @@ -4,7 +4,7 @@ namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { void RemoveDropout(std::shared_ptr& graph) { std::string dropout_pattern = R"IR( @@ -14,7 +14,7 @@ void RemoveDropout(std::shared_ptr& graph) { std::string no_dropout_pattern = R"IR( graph(%input, %4, %5): return (%input))IR"; - + // replace matmul + add pattern to linear torch::jit::SubgraphRewriter remove_dropout; remove_dropout.RegisterRewritePattern( @@ -22,7 +22,7 @@ void RemoveDropout(std::shared_ptr& graph) { remove_dropout.runOnGraph(graph); } -} // namespace irfusers +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/irfusers/unpack_batch_norm.cpp b/core/lowering/passes/unpack_batch_norm.cpp similarity index 98% rename from core/lowering/irfusers/unpack_batch_norm.cpp rename to core/lowering/passes/unpack_batch_norm.cpp index 8a10d747f8..b75af2c2f9 100644 --- a/core/lowering/irfusers/unpack_batch_norm.cpp +++ b/core/lowering/passes/unpack_batch_norm.cpp @@ -23,7 +23,7 @@ RegisterOperators trt_const_op_reg({ namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { // // May be abusing aten::_tensor_to_list(Tensor self) -> int[] // // Treating it as an emit_constant by the converters @@ -60,7 +60,7 @@ void UnpackBatchNorm(std::shared_ptr& graph) { unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern); unpack_batch_norm.runOnGraph(graph); } -} // Namespace Irfusers +} // Namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/util/logging/TRTorchLogger.cpp b/core/util/logging/TRTorchLogger.cpp index 678506c09f..7c80152076 100644 --- a/core/util/logging/TRTorchLogger.cpp +++ b/core/util/logging/TRTorchLogger.cpp @@ -101,7 +101,7 @@ namespace { TRTorchLogger& get_global_logger() { #ifndef NDEBUG static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", - LogLevel::kDEBUG, + LogLevel::kGRAPH, true); #else static TRTorchLogger global_logger("[TRTorch] - ", diff --git a/core/util/macros.h b/core/util/macros.h index 7a1d1b4540..ccea57e513 100644 --- a/core/util/macros.h +++ b/core/util/macros.h @@ -11,21 +11,21 @@ l.log(sev, ss.str()); \ } while (0) -#define GRAPH_DUMP_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s) +#define LOG_GRAPH_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s) #define LOG_DEBUG_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kDEBUG, s) #define LOG_INFO_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINFO, s) #define LOG_WARNING_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kWARNING, s) #define LOG_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kERROR, s) #define LOG_INTERNAL_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINTERNAL_ERROR, s) -#define GRAPH_DUMP_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s) +#define LOG_GRAPH_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s) #define LOG_DEBUG_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kDEBUG, s) #define LOG_INFO_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINFO, s) #define LOG_WARNING_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kWARNING, s) #define LOG_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kERROR, s) #define LOG_INTERNAL_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINTERNAL_ERROR, s) -#define GRAPH_DUMP(...) GET_MACRO(__VA_ARGS__, GRAPH_DUMP_OWN, GRAPH_DUMP_GLOBAL)(__VA_ARGS__) +#define LOG_GRAPH(...) GET_MACRO(__VA_ARGS__, LOG_GRAPH_OWN, LOG_GRAPH_GLOBAL)(__VA_ARGS__) #define LOG_DEBUG(...) GET_MACRO(__VA_ARGS__, LOG_DEBUG_OWN, LOG_DEBUG_GLOBAL)(__VA_ARGS__) #define LOG_INFO(...) GET_MACRO(__VA_ARGS__, LOG_INFO_OWN, LOG_INFO_GLOBAL)(__VA_ARGS__) #define LOG_WARNING(...) GET_MACRO(__VA_ARGS__, LOG_WARNING_OWN, LOG_WARNING_GLOBAL)(__VA_ARGS__) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index b1b3a1a102..b3dc394478 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -20,5 +20,7 @@ print("Downloading {}".format(n)) m = m.eval().cuda() x = torch.ones((1, 3, 224, 224)).cuda() - jit_model = torch.jit.trace(m, x) - torch.jit.save(jit_model, n + '.jit.pt') + trace_model = torch.jit.trace(m, x) + torch.jit.save(trace_model, n + '_traced.jit.pt') + script_model = torch.jit.script(m) + torch.jit.save(script_model, n + '_scripted.jit.pt') \ No newline at end of file diff --git a/tests/modules/test_compiled_modules.cpp b/tests/modules/test_compiled_modules.cpp index 199e1d81b5..9a5c9daf1d 100644 --- a/tests/modules/test_compiled_modules.cpp +++ b/tests/modules/test_compiled_modules.cpp @@ -28,9 +28,9 @@ TEST_P(ModuleTests, CompiledModuleIsClose) { INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite, ModuleTests, testing::Values( - PathAndInSize({"tests/modules/resnet18.jit.pt", + PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/resnet50.jit.pt", + PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", + PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1,3,224,224}}}))); diff --git a/tests/modules/test_modules_as_engines.cpp b/tests/modules/test_modules_as_engines.cpp index 759ed136c9..d190251bb3 100644 --- a/tests/modules/test_modules_as_engines.cpp +++ b/tests/modules/test_modules_as_engines.cpp @@ -19,9 +19,9 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) { INSTANTIATE_TEST_SUITE_P(ModuleAsEngineForwardIsCloseSuite, ModuleTests, testing::Values( - PathAndInSize({"tests/modules/resnet18.jit.pt", + PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/resnet50.jit.pt", + PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", + PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1,3,224,224}}}))); \ No newline at end of file diff --git a/tests/modules/test_multiple_registered_engines.cpp b/tests/modules/test_multiple_registered_engines.cpp index 7ce3dbf61f..c03e68c0b4 100644 --- a/tests/modules/test_multiple_registered_engines.cpp +++ b/tests/modules/test_multiple_registered_engines.cpp @@ -8,8 +8,8 @@ TEST(ModuleTests, CanRunMultipleEngines) { torch::jit::script::Module mod1; torch::jit::script::Module mod2; try { - mod1 = torch::jit::load("tests/modules/resnet50.jit.pt"); - mod2 = torch::jit::load("tests/modules/resnet18.jit.pt"); + mod1 = torch::jit::load("tests/modules/resnet50_traced.jit.pt"); + mod2 = torch::jit::load("tests/modules/resnet18_traced.jit.pt"); } catch (const c10::Error& e) { std::cerr << "error loading the model\n";