Skip to content

Commit

Permalink
Merge pull request #1148 from pytorch/fix_parsing
Browse files Browse the repository at this point in the history
fix: fix the parsing-related model loading bug
  • Loading branch information
peri044 authored Jul 25, 2022
2 parents 5cb5947 + e7c359d commit e07687d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
6 changes: 5 additions & 1 deletion core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto graph_and_mapping =
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
LOG_INFO("Segmented Graph: " << *new_g);
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
new_g->inputs()[i]->setDebugName(std::string("input_") + std::to_string(i));
}
LOG_INFO(*new_g << "(GraphAfterFallback)");

// if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
// module
Expand Down
18 changes: 17 additions & 1 deletion tests/core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ partitioning_test(
name = "test_resolve_nontensor_inputs",
)

# Regression test target for the model (de)serialization fix: the compiled
# module must survive a save/load round trip after partitioning/fallback.
cc_test(
    name = "test_loading_model",
    srcs = ["test_loading_model.cpp"],
    # TorchScript fixture modules generated by the test harness.
    data = [
        ":jit_models",
    ],
    deps = [
        "//tests/util",
        "@googletest//:gtest_main",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
        "//conditions:default": ["@libtorch//:libtorch"],
    }),
)

cc_test(
name = "test_fallback_graph_output",
srcs = ["test_fallback_graph_output.cpp"],
Expand Down Expand Up @@ -92,6 +107,7 @@ test_suite(
":test_fallback_graph_output",
":test_loop_fallback",
":test_conditionals",
":test_resolve_nontensor_inputs"
":test_resolve_nontensor_inputs",
":test_loading_model"
]
)
39 changes: 39 additions & 0 deletions tests/core/partitioning/test_loading_model.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <string>
#include <unordered_set>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"

#ifndef DISABLE_TEST_IN_CI

// Verifies that a module compiled with partitioning (fallback) enabled can be
// serialized and deserialized again — guarding the graph-input renaming fix in
// core/compiler.cpp — and that the reloaded module still runs.
// NOTE(review): despite the name, this loads conditional_scripted.jit.pt, not
// ResNet50 — presumably a copy/paste of an existing test's name; confirm intent.
TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
  } catch (const c10::Error& e) {
    // Best-effort skip when the fixture module is unavailable (e.g. not generated).
    std::cerr << "error loading the model\n";
    return;
  }

  // One 1x3x224x224 CUDA input; identical clones feed the JIT baseline and the
  // TRT-compiled module so both see the same data.
  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (auto in_shape : input_shapes) {
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};

  torch_tensorrt::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;  // force the fallback/partitioning path under test

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);

  // Round-trip through TorchScript serialization: both save and load must succeed.
  trt_mod.save("loading_model.ts");
  auto loaded_model = torch::jit::load("loading_model.ts");

  // Previously jit_results and loaded_model were unused (the test could only fail
  // by throwing). Now assert the deserialized module is runnable and produces an
  // output tensor of the same shape as the original JIT module.
  auto loaded_results = loaded_model.forward(trt_inputs_ivalues).toTensor();
  ASSERT_EQ(jit_results.sizes(), loaded_results.sizes());
}

#endif

0 comments on commit e07687d

Please sign in to comment.