From 282e98aff1fecf8f0a12f814c7facc4deeab3b2b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 9 Oct 2021 08:57:14 -0700 Subject: [PATCH] fix: Fix modules_as_engines test case to use trt_mod instead of pyt_mod Signed-off-by: Dheeraj Peri --- tests/cpp/test_modules_as_engines.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_modules_as_engines.cpp b/tests/cpp/test_modules_as_engines.cpp index 58e769933f..1c7006d173 100644 --- a/tests/cpp/test_modules_as_engines.cpp +++ b/tests/cpp/test_modules_as_engines.cpp @@ -29,7 +29,6 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) { std::vector jit_results; jit_results.push_back(jit_results_ivalues.toTensor()); - auto forward_graph = mod.get_method("forward"); std::vector> input_ranges; for (auto in : inputs) { input_ranges.push_back(in.sizes()); @@ -43,7 +42,7 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) { auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges); auto trt_mod = trtorch::EmbedEngineInNewModule(engine, compile_spec.device); - torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues); + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues); std::vector trt_results; trt_results.push_back(trt_results_ivalues.toTensor()); @@ -61,4 +60,4 @@ INSTANTIATE_TEST_SUITE_P( PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3}))); \ No newline at end of file + PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));