Skip to content

Commit

Permalink
fix/feat: Add lowering pass to resolve most aten::Int.Tensor uses (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored and narendasan committed Jun 3, 2023
1 parent ae564d7 commit f67dd1d
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
passes::SiluToSigmoidMultipication(g);
passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
passes::ReplaceAtenInt(g);
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
145 changes: 145 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
Expand Down Expand Up @@ -211,6 +212,150 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
}

// Schemas of binary Tensor operators for which aten::Int.Tensor uses can be
// replaced by scalar equivalents: aten::mul's scalar overload shares the same
// schema name, while aten::floor_divide maps to the scalar aten::floordiv
// (see TracebackAndEliminate0DTensors below)
const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
    torch::jit::aten::mul,
    torch::jit::aten::floor_divide,
};

c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
  // Determine whether the given Value* is already a scalar (int/float), or a
  // 0-dimensional integral Tensor whose scalar payload can be recovered.
  // Returns the scalar Value* on success, otherwise an empty optional.

  // Already a scalar: nothing to unwrap
  if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
    return value;
  }

  auto producer = value->node();

  // Constant Tensor: unwrap it when it is 0-D (single-element) and integral
  if (producer->kind() == torch::jit::prim::Constant && value->type()->isSubtypeOf(c10::TensorType::get())) {
    // Retrieve the Tensor stored in the constant node
    at::Tensor stored_tensor = *torch::jit::constant_as<at::Tensor>(value);
    if (stored_tensor.sizes() == std::vector<int64_t>({}) && stored_tensor.item().isIntegral(false)) {
      // Re-insert the payload into the graph as an int constant, placed at the
      // original constant's location so downstream users can reference it
      torch::jit::WithInsertPoint guard(producer);
      auto scalar_constant = value->owningGraph()->insertConstant(stored_tensor.item(), c10::nullopt, producer->scope());
      scalar_constant->copyMetadata(value);
      scalar_constant->setType(c10::IntType::get());
      return scalar_constant;
    }
    LOG_DEBUG("In aten::Int.Tensor removal, encountered a const which was either not 0D or not integral");
  }

  // prim::NumToTensor wraps a scalar: its input is exactly the scalar we want
  if (producer->kind() == torch::jit::prim::NumToTensor && value->type()->isSubtypeOf(c10::TensorType::get())) {
    return producer->input();
  }

  // Not a recognizable scalar-carrying Value
  return {};
}

c10::optional<torch::jit::Value*> TracebackAndEliminate0DTensors(torch::jit::Node* node) {
  // Trace back through a node and all parents to eliminate 0D Tensors
  // and update schemas to their scalar alternatives, returning the final
  // scalar Value* to the caller (or an empty optional on failure)

  // Requires a supported schema with at least two inputs
  if (AtenIntReplacementNodeKinds.find(node->kind()) == AtenIntReplacementNodeKinds.end() ||
      node->inputs().size() < 2) {
    LOG_DEBUG(
        "Encountered node " << node->kind().toQualString()
                            << " which is unsupported in the aten::Int.Tensor replacement lowering pass.");
    return {};
  }

  // Validate the first and second function inputs are 0D tensors or scalars
  c10::optional<torch::jit::Value*> first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
  c10::optional<torch::jit::Value*> second_input_scalar_value = Validate0DTensor(node->inputs()[1]);

  // If the first input is not a scalar, recursively traceback on parent nodes
  if (!first_input_scalar_value.has_value()) {
    LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
    first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node());
  }

  // If the second input is not a scalar, recursively traceback on parent nodes
  // Bugfix: previously this logged inputs()[0]'s parent kind (copy-paste from
  // the branch above), misreporting which parent node was being traced
  if (!second_input_scalar_value.has_value()) {
    LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[1]->node()->kind().toQualString());
    second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node());
  }

  if (!first_input_scalar_value.has_value() || !second_input_scalar_value.has_value()) {
    LOG_DEBUG(
        "In aten::Int.Tensor lowering, recursive trace through node input "
        << "parents failed to return a Scalar value for at least one parent node.");
    return {};
  }

  // Set default insert point at node
  torch::jit::WithInsertPoint guard(node);

  // Select the scalar replacement schema: aten::floor_divide's scalar
  // counterpart is named aten::floordiv, while aten::mul (and any other
  // supported kind) shares its schema name between Tensor and scalar overloads
  const auto replacement_kind =
      (node->kind() == torch::jit::aten::floor_divide) ? torch::jit::aten::floordiv : node->kind();

  // Insert the scalar-typed replacement node directly after the original;
  // the Tensor-typed original becomes dead and is later removed by DCE
  auto new_node = node->owningGraph()->create(
      replacement_kind, {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
  new_node->insertAfter(node);
  new_node->output()->setType(c10::IntType::get());
  return new_node->output();
}

void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
  // Find all nodes with the aten::Int.Tensor schema and replace those
  // by tracing through the node and resolving the use of 0D tensors
  // to their corresponding scalar alternatives
  //
  // NOTE(review): only the graph's top-level block is traversed here, so
  // aten::Int calls nested inside sub-blocks (e.g. prim::If / prim::Loop)
  // are presumably left untouched — confirm whether that case matters.

  // Iterate over all nodes in the graph
  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
    // Validate schema requirements for aten::Int.Tensor: one Tensor input
    if (it->kind() == torch::jit::aten::Int && it->inputs().size() == 1 &&
        it->input()->type()->isSubtypeOf(c10::TensorType::get())) {
      LOG_DEBUG("Found an aten::Int.Tensor case, attempting to resolve input scalars.");

      // If the node parent schema is of a supported type, trace back through the graph
      if (AtenIntReplacementNodeKinds.find(it->input()->node()->kind()) != AtenIntReplacementNodeKinds.end()) {
        LOG_DEBUG(
            "Tracing parent node " << it->input()->node()->kind().toQualString()
                                   << " to eliminate 0D Tensors for aten::Int.Tensor case.");
        auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node());
        if (scalar_input_value.has_value()) {
          // Rewire every user of the aten::Int output onto the recovered
          // scalar; the now-dead Tensor chain is swept by DCE below
          it->output()->replaceAllUsesWith(scalar_input_value.value());
          LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded.");
        } else {
          LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed.");
        }
      } else {
        LOG_DEBUG(
            "Parent node schema " << it->input()->node()->kind().toQualString()
                                  << " is currently unsupported for aten::Int.Tensor case.");
      }
    }
  }

  // Clean up remnant operators in graph
  torch::jit::EliminateDeadCode(g);
  // Fix: message previously read "aten.Int.Tensor", inconsistent with the
  // "aten::Int.Tensor" spelling used by every other log line in this pass
  LOG_GRAPH("Post removing aten::Int.Tensor operations: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
Expand Down
152 changes: 152 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, RemoveAtenIntTensorValuesAgree) {
  // IR pair: the source routes arithmetic through 0-D Tensors and aten::Int,
  // while the target performs the identical computation directly on ints.
  // NOTE(review): the target IR re-binds %4 (first as int, later as Tensor);
  // the parser appears to tolerate re-definition here — confirm intended.
  std::string source_graph_no_inputs = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%11: int = prim::Constant[value=7]()
%3: Tensor = prim::NumToTensor(%0)
%1: Tensor = prim::NumToTensor(%11)
%4: Tensor = aten::floor_divide(%1, %3)
%7: Tensor = aten::mul(%3, %4)
%8: Tensor = aten::mul(%7, %1)
%50: int = aten::Int(%8)
%5: Tensor = prim::NumToTensor(%50)
return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%1: int = prim::Constant[value=7]()
%4: int = aten::floordiv(%1, %0)
%7: int = aten::mul(%0, %4)
%40: int = aten::mul(%7, %1)
%4: Tensor = prim::NumToTensor(%40)
return (%4))IR";

  // Evaluate both graphs under the JIT and require numerically equal outputs
  auto source_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, source_g.get());
  auto target_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, target_g.get());

  auto source_results = torch_tensorrt::tests::util::EvaluateGraphJIT(source_g, {});
  auto target_results = torch_tensorrt::tests::util::EvaluateGraphJIT(target_g, {});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(source_results[0].toTensor(), target_results[0].toTensor(), 2e-6));

  // Ensure the lowering pass transforms the first graph into the second
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto lowered_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, lowered_g.get());

  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(lowered_g);

  auto reference_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, reference_g.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*reference_g, *lowered_g).empty());
}

TEST(LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) {
  // Same as the test above, but the first operand originates from a dynamic
  // aten::size call on a graph input rather than from a constant.
  // NOTE(review): the target IR re-binds %4 (first as int, later as Tensor);
  // the parser appears to tolerate re-definition here — confirm intended.
  std::string source_graph_no_inputs = R"IR(
graph(%x.0: Tensor):
%10: int = prim::Constant[value=0]()
%100: int = aten::size(%x.0, %10)
%0: Tensor = prim::NumToTensor(%100)
%11: int = prim::Constant[value=9]()
%1: Tensor = prim::NumToTensor(%11)
%4: Tensor = aten::floor_divide(%1, %0)
%7: Tensor = aten::mul(%0, %4)
%8: Tensor = aten::mul(%7, %1)
%50: int = aten::Int(%8)
%5: Tensor = prim::NumToTensor(%50)
return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
graph(%x.0: Tensor):
%10: int = prim::Constant[value=0]()
%0: int = aten::size(%x.0, %10)
%1: int = prim::Constant[value=9]()
%4: int = aten::floordiv(%1, %0)
%7: int = aten::mul(%0, %4)
%40: int = aten::mul(%7, %1)
%4: Tensor = prim::NumToTensor(%40)
return (%4))IR";

  // Evaluate both graphs on the same random input and compare outputs
  auto input_tensor = at::rand({2, 3, 5, 5}, {at::kCUDA});

  auto source_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, source_g.get());
  auto target_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, target_g.get());

  auto source_results = torch_tensorrt::tests::util::EvaluateGraphJIT(source_g, {input_tensor});
  auto target_results = torch_tensorrt::tests::util::EvaluateGraphJIT(target_g, {input_tensor});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(source_results[0].toTensor(), target_results[0].toTensor(), 2e-6));

  // Ensure the lowering pass transforms the first graph into the second
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto lowered_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, lowered_g.get());

  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(lowered_g);

  auto reference_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, reference_g.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*reference_g, *lowered_g).empty());
}

TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
  // Ensure the lowering pass transforms the first graph into the second
  std::string source_graph = R"IR(
graph(%0: int):
%1: Tensor = prim::Constant[value=[8]]()
%3: Tensor = prim::NumToTensor(%0)
%4: Tensor = aten::floor_divide(%3, %1)
%5: int = aten::Int(%4)
return (%5))IR";

  std::string target_graph = R"IR(
graph(%0 : int):
%1 : Tensor = prim::Constant[value=[8]]()
%2 : int = prim::Constant[value=8]()
%3 : int = aten::floordiv(%0, %2)
return (%3))IR";

  // Replace the first node's output in a graph with a manually inserted
  // 0-D tensor constant (value 8), then delete the original node; the IR
  // parser cannot express a 0-D tensor constant directly
  auto inject_0d_const = [](std::shared_ptr<torch::jit::Graph>& graph) {
    auto first_op = *(graph->block()->nodes().begin());
    torch::jit::Value* replacement =
        graph->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
    replacement->copyMetadata(first_op->output());
    replacement->setType(c10::TensorType::get());
    first_op->output()->replaceAllUsesWith(replacement);
    first_op->destroy();
  };

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*sg);
  inject_0d_const(sg);

  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*tg);
  inject_0d_const(tg);

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == sg->toString()));
}

0 comments on commit f67dd1d

Please sign in to comment.