Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Repair invalid schema arising from lowering pass #1786

Merged
merged 1 commit into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,48 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
case c10::aten::div:
// If the first two entries to aten::div are non-Tensors,
// there cannot be a rounding mode specified (3rd entry)
if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) &&
!user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) &&
user->inputs().size() == 3 &&
user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) &&
torch::jit::toIValue(user->inputs()[2]).has_value()) {
// Select the first 2 entries of the inputs, corresponding to the values
auto div_args = user->inputs().slice(0, 2);

// Depending on the rounding mode, create the appropriate nodes
if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") {
// Truncate case (round result towards 0)
torch::jit::Node* new_node_div;
// Create node which simply divides the two entries
new_node_div = g->create(c10::aten::div, div_args, 1);
new_node_div->insertAfter(user);
new_node_div->outputs()[0]->setType(c10::FloatType::get());

// Create node which casts the result to an integer, effectively truncating
new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1);
new_node->insertAfter(new_node_div);
new_node->outputs()[0]->setType(c10::IntType::get());

user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;

} else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") {
// Floor case (round result down)
// Replace aten::div with aten::floordiv
new_node = g->create(c10::aten::floordiv, div_args, 1);
new_node->insertAfter(user);
new_node->outputs()[0]->setType(c10::IntType::get());

user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
}
}

default:
new_node = g->create(user->kind(), user->inputs(), 1);
new_node->insertAfter(user);
Expand Down
151 changes: 151 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/constant_pooling.h"

TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
std::string source_graph = R"IR(
Expand Down Expand Up @@ -255,6 +257,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsDivTruncIntValuesAgree) {
  // Phase 1: sanity-check that the pre-lowering IR and the expected
  // post-lowering IR evaluate to the same value under the JIT
  // (aten::div with rounding_mode="trunc" vs. float div + aten::Int).
  std::string pre_lowering_ir = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%11: int = prim::Constant[value=-3]()
%234 : str = prim::Constant[value="trunc"]()
%3: Tensor = prim::NumToTensor(%0)
%1: Tensor = prim::NumToTensor(%11)
%4: Tensor = aten::div(%1, %3, %234)
%50: int = aten::Int(%4)
%5: Tensor = prim::NumToTensor(%50)
return (%5))IR";
  std::string post_lowering_ir = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%1: int = prim::Constant[value=-3]()
%40: float = aten::div(%1, %0)
%41: int = aten::Int(%40)
%4: Tensor = prim::NumToTensor(%41)
return (%4))IR";

  auto graph_before = std::make_shared<torch::jit::Graph>();
  auto graph_after = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(pre_lowering_ir, graph_before.get());
  torch::jit::parseIR(post_lowering_ir, graph_after.get());

  auto results_before = torch_tensorrt::tests::util::EvaluateGraphJIT(graph_before, {});
  auto results_after = torch_tensorrt::tests::util::EvaluateGraphJIT(graph_after, {});
  ASSERT_TRUE(
      torch_tensorrt::tests::util::exactlyEqual(results_before[0].toTensor(), results_after[0].toTensor()));

  // Phase 2: check that the lowering pass itself rewrites the source
  // graph into the expected target graph.
  std::string source_graph = R"IR(
graph(%0: int):
%1: Tensor = prim::Constant[value=[8]]()
%3: Tensor = prim::NumToTensor(%0)
%234: str = prim::Constant[value="trunc"]()
%4: Tensor = aten::div(%3, %1, %234)
%5: int = aten::Int(%4)
return (%5))IR";

  std::string target_graph = R"IR(
graph(%0 : int):
%1 : str = prim::Constant[value="trunc"]()
%2 : int = prim::Constant[value=8]()
%3 : float = aten::div(%0, %2)
%4 : int = aten::Int(%3)
return (%4))IR";

  auto lowered_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*lowered_graph);

  // Swap the parsed constant-list node for a genuine 0-D tensor constant,
  // since the IR parser cannot express a scalar tensor constant directly.
  auto placeholder_node = *(lowered_graph->block()->nodes().begin());
  torch::jit::Value* scalar_const =
      lowered_graph->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, placeholder_node->scope());
  scalar_const->copyMetadata(placeholder_node->output());
  scalar_const->setType(c10::TensorType::get());
  placeholder_node->output()->replaceAllUsesWith(scalar_const);
  placeholder_node->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(lowered_graph);
  torch::jit::ConstantPooling(lowered_graph);
  lowered_graph = torch::jit::Canonicalize(lowered_graph, false);

  auto expected_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*expected_graph);
  torch::jit::ConstantPooling(expected_graph);
  expected_graph = torch::jit::Canonicalize(expected_graph, false);

  // After constant pooling and canonicalization the two graphs should
  // print identically.
  ASSERT_TRUE((expected_graph->toString() == lowered_graph->toString()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsDivFloorIntValuesAgree) {
  // Step 1: confirm numerical equivalence between the original IR and
  // the IR the lowering pass should produce (aten::div with
  // rounding_mode="floor" vs. aten::floordiv).
  std::string ir_with_tensors = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%11: int = prim::Constant[value=-3]()
%234 : str = prim::Constant[value="floor"]()
%3: Tensor = prim::NumToTensor(%0)
%1: Tensor = prim::NumToTensor(%11)
%4: Tensor = aten::div(%1, %3, %234)
%50: int = aten::Int(%4)
%5: Tensor = prim::NumToTensor(%50)
return (%5))IR";
  std::string ir_without_tensors = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%1: int = prim::Constant[value=-3]()
%40: int = aten::floordiv(%1, %0)
%41: int = aten::Int(%40)
%4: Tensor = prim::NumToTensor(%41)
return (%4))IR";

  auto unlowered = std::make_shared<torch::jit::Graph>();
  auto reference = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir_with_tensors, unlowered.get());
  torch::jit::parseIR(ir_without_tensors, reference.get());

  auto unlowered_out = torch_tensorrt::tests::util::EvaluateGraphJIT(unlowered, {});
  auto reference_out = torch_tensorrt::tests::util::EvaluateGraphJIT(reference, {});
  ASSERT_TRUE(
      torch_tensorrt::tests::util::exactlyEqual(unlowered_out[0].toTensor(), reference_out[0].toTensor()));

  // Step 2: run the lowering pass and verify the rewritten graph matches
  // the expected target graph.
  std::string source_graph = R"IR(
graph(%0: int):
%1: Tensor = prim::Constant[value=[8]]()
%3: Tensor = prim::NumToTensor(%0)
%234: str = prim::Constant[value="floor"]()
%4: Tensor = aten::div(%3, %1, %234)
%5: int = aten::Int(%4)
return (%5))IR";

  std::string target_graph = R"IR(
graph(%0 : int):
%1 : str = prim::Constant[value="floor"]()
%2 : int = prim::Constant[value=8]()
%3 : int = aten::floordiv(%0, %2)
return (%3))IR";

  auto actual = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*actual);

  // The IR parser cannot express a 0-D tensor constant, so replace the
  // first (list-constant) node with a real scalar tensor constant.
  auto stand_in = *(actual->block()->nodes().begin());
  torch::jit::Value* tensor_const =
      actual->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, stand_in->scope());
  tensor_const->copyMetadata(stand_in->output());
  tensor_const->setType(c10::TensorType::get());
  stand_in->output()->replaceAllUsesWith(tensor_const);
  stand_in->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(actual);
  torch::jit::ConstantPooling(actual);
  actual = torch::jit::Canonicalize(actual, false);

  auto expected = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*expected);
  torch::jit::ConstantPooling(expected);
  expected = torch::jit::Canonicalize(expected, false);

  // Identical textual form after pooling + canonicalization implies the
  // pass produced exactly the expected graph.
  ASSERT_TRUE((expected->toString() == actual->toString()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
std::string source_graph_no_inputs = R"IR(
graph():
Expand Down