diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index fa179a3922..4184b2f6be 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -143,6 +143,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
 void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
index 451e77238e..672c30409d 100644
--- a/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -1,4 +1,5 @@
 #include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"
 
 #include "core/util/prelude.h"
@@ -211,6 +212,150 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
   LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
 }
 
+// Schemas for aten::Int which can be replaced by scalar equivalents
+const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
+    torch::jit::aten::mul,
+    torch::jit::aten::floor_divide,
+};
+
+c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
+  // Validates that the input Value* is a 0D Tensor (or int/float)
+  // Returns the stored int/float Value* if so, otherwise an empty optional
+  c10::optional<torch::jit::Value*> enclosed_scalar_value = {};
+
+  // Regular Int/Float case
+  if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
+    enclosed_scalar_value = value;
+    return enclosed_scalar_value;
+  }
+
+  // Constant Tensor case
+  if (value->node()->kind() == torch::jit::prim::Constant && value->type()->isSubtypeOf(c10::TensorType::get())) {
+    // Retrieve the Tensor stored in the constant
+    at::Tensor t = *torch::jit::constant_as<at::Tensor>(value);
+    // Validate the shape of the Tensor is 0D (single-element) and integral
+    if (t.sizes() == std::vector<int64_t>({}) && t.item().isIntegral(false)) {
+      // Extract the stored value and add it to the graph as a scalar constant
+      torch::jit::WithInsertPoint guard(value->node());
+      auto new_const_val = value->owningGraph()->insertConstant(t.item(), c10::nullopt, value->node()->scope());
+      new_const_val->copyMetadata(value);
+      new_const_val->setType(c10::IntType::get());
+      enclosed_scalar_value = new_const_val;
+      return enclosed_scalar_value;
+    } else {
+      LOG_DEBUG("In aten::Int.Tensor removal, encountered a const which was either not 0D or not integral");
+    }
+  }
+
+  // NumToTensor case
+  if (value->node()->kind() == torch::jit::prim::NumToTensor && value->type()->isSubtypeOf(c10::TensorType::get())) {
+    // The input to NumToTensor is the relevant scalar
+    enclosed_scalar_value = value->node()->input();
+    return enclosed_scalar_value;
+  }
+
+  return enclosed_scalar_value;
+}
+
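+// For illustration: given IR of the form
+//   %3 : Tensor = prim::NumToTensor(%0)
+// Validate0DTensor(%3) returns %0, the scalar feeding the NumToTensor node,
+// while a 0D integral constant Tensor is re-inserted as an int constant.
+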
lowering pass."); + return {}; + } + + // Validate the first and second function inputs are 0D tensors or scalars + c10::optional first_input_scalar_value = Validate0DTensor(node->inputs()[0]); + c10::optional second_input_scalar_value = Validate0DTensor(node->inputs()[1]); + + // If the first input is not a scalar, recursively traceback on parent nodes + if (!first_input_scalar_value.has_value()) { + LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString()); + first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node()); + } + + // If the second input is not a scalar, recursively traceback on parent nodes + if (!second_input_scalar_value.has_value()) { + LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString()); + second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node()); + } + + if (!first_input_scalar_value.has_value() || !second_input_scalar_value.has_value()) { + LOG_DEBUG( + "In aten::Int.Tensor lowering, recursive trace through node input " + << "parents failed to return a Scalar value for at least one parent node."); + return {}; + } + + // Set default insert point at node + torch::jit::WithInsertPoint guard(node); + torch::jit::Node* new_node; + + switch (node->kind()) { + // In the aten::floor_divide case, the schema syntax changes, so a new node + // must be inserted + case torch::jit::aten::floor_divide: + new_node = node->owningGraph()->create( + torch::jit::aten::floordiv, {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1); + new_node->insertAfter(node); + new_node->output()->setType(c10::IntType::get()); + return new_node->output(); + + // In the aten::mul case, the schema syntax is the same, so we can use the existing schema + // with new inputs + default: + new_node = node->owningGraph()->create( + node->kind(), {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1); + new_node->insertAfter(node); + new_node->output()->setType(c10::IntType::get()); + return new_node->output(); + } +} + +void ReplaceAtenInt(std::shared_ptr& g) { + // Find all nodes with the aten::Int.Tensor schema and replace those + // by tracing through the node and resolving the use of 0D tensors + // to their corresponding scalar alternatives + + // Iterate over all nodes in the graph + for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) { + // Validate schema requirements for aten::Int.Tensor + if (it->kind() == torch::jit::aten::Int && it->inputs().size() == 1 && + it->input()->type()->isSubtypeOf(c10::TensorType::get())) { + LOG_DEBUG("Found an aten::Int.Tensor case, attempting to resolve input scalars."); + + // If the node parent schema is of a supported type, trace back through the graph + if (AtenIntReplacementNodeKinds.find(it->input()->node()->kind()) != AtenIntReplacementNodeKinds.end()) { + LOG_DEBUG( + "Tracing parent node " << it->input()->node()->kind().toQualString() + << " to eliminate 0D Tensors for aten::Int.Tensor case."); + auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node()); + if (scalar_input_value.has_value()) { + it->output()->replaceAllUsesWith(scalar_input_value.value()); + LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded."); + } else { + LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed."); + } + } else { + LOG_DEBUG( + "Parent node schema " << it->input()->node()->kind().toQualString() 
+ << " is currently unsupported for aten::Int.Tensor case."); + } + } + } + + // Clean up remnant operators in graph + torch::jit::EliminateDeadCode(g); + LOG_GRAPH("Post removing aten.Int.Tensor operations: " << *g); +} + } // namespace passes } // namespace lowering } // namespace core diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp index 704b2064ea..488d7988ea 100644 --- a/tests/core/lowering/test_remove_unnecessary_casts.cpp +++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp @@ -437,3 +437,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) { ASSERT_TRUE( torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); } + +TEST(LoweringPasses, RemoveAtenIntTensorValuesAgree) { + std::string source_graph_no_inputs = R"IR( + graph(): + %0: int = prim::Constant[value=2]() + %11: int = prim::Constant[value=7]() + %3: Tensor = prim::NumToTensor(%0) + %1: Tensor = prim::NumToTensor(%11) + %4: Tensor = aten::floor_divide(%1, %3) + %7: Tensor = aten::mul(%3, %4) + %8: Tensor = aten::mul(%7, %1) + %50: int = aten::Int(%8) + %5: Tensor = prim::NumToTensor(%50) + return (%5))IR"; + std::string target_graph_no_inputs = R"IR( + graph(): + %0: int = prim::Constant[value=2]() + %1: int = prim::Constant[value=7]() + %4: int = aten::floordiv(%1, %0) + %7: int = aten::mul(%0, %4) + %40: int = aten::mul(%7, %1) + %4: Tensor = prim::NumToTensor(%40) + return (%4))IR"; + + auto g_in = std::make_shared(); + auto g_out = std::make_shared(); + + torch::jit::parseIR(source_graph_no_inputs, g_in.get()); + torch::jit::parseIR(target_graph_no_inputs, g_out.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {}); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); + + // Ensure the lowering pass transforms the first graph into the second + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph_no_inputs, sg.get()); + + torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph_no_inputs, tg.get()); + + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); +} + +TEST(LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) { + std::string source_graph_no_inputs = R"IR( + graph(%x.0: Tensor): + %10: int = prim::Constant[value=0]() + %100: int = aten::size(%x.0, %10) + %0: Tensor = prim::NumToTensor(%100) + %11: int = prim::Constant[value=9]() + %1: Tensor = prim::NumToTensor(%11) + %4: Tensor = aten::floor_divide(%1, %0) + %7: Tensor = aten::mul(%0, %4) + %8: Tensor = aten::mul(%7, %1) + %50: int = aten::Int(%8) + %5: Tensor = prim::NumToTensor(%50) + return (%5))IR"; + std::string target_graph_no_inputs = R"IR( + graph(%x.0: Tensor): + %10: int = prim::Constant[value=0]() + %0: int = aten::size(%x.0, %10) + %1: int = prim::Constant[value=9]() + %4: int = aten::floordiv(%1, %0) + %7: int = aten::mul(%0, %4) + %40: int = aten::mul(%7, %1) + %4: Tensor = prim::NumToTensor(%40) + return (%4))IR"; + + auto g_in = std::make_shared(); + auto g_out = std::make_shared(); + + auto in_0 = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + 
+TEST(LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph(%x.0: Tensor):
+      %10: int = prim::Constant[value=0]()
+      %100: int = aten::size(%x.0, %10)
+      %0: Tensor = prim::NumToTensor(%100)
+      %11: int = prim::Constant[value=9]()
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %0)
+      %7: Tensor = aten::mul(%0, %4)
+      %8: Tensor = aten::mul(%7, %1)
+      %50: int = aten::Int(%8)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph(%x.0: Tensor):
+      %10: int = prim::Constant[value=0]()
+      %0: int = aten::size(%x.0, %10)
+      %1: int = prim::Constant[value=9]()
+      %4: int = aten::floordiv(%1, %0)
+      %7: int = aten::mul(%0, %4)
+      %40: int = aten::mul(%7, %1)
+      %5: Tensor = prim::NumToTensor(%40)
+      return (%5))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  auto in_0 = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {in_0});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {in_0});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph_no_inputs, sg.get());
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph_no_inputs, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%3, %1)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0 : int):
+      %1 : Tensor = prim::Constant[value=[8]]()
+      %2 : int = prim::Constant[value=8]()
+      %3 : int = aten::floordiv(%0, %2)
+      return (%3))IR";
+
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  // Manually enter a 0D tensor constant for the source
+  auto first_op_sg = *(sg->block()->nodes().begin());
+  torch::jit::Value* r_sg = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_sg->scope());
+  r_sg->copyMetadata(first_op_sg->output());
+  r_sg->setType(c10::TensorType::get());
+  first_op_sg->output()->replaceAllUsesWith(r_sg);
+  first_op_sg->destroy();
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  // Manually enter a 0D tensor constant for the target
+  auto first_op_tg = *(tg->block()->nodes().begin());
+  torch::jit::Value* r_tg = tg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_tg->scope());
+  r_tg->copyMetadata(first_op_tg->output());
+  r_tg->setType(c10::TensorType::get());
+  first_op_tg->output()->replaceAllUsesWith(r_tg);
+  first_op_tg->destroy();
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}
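
For anyone exercising the new pass outside of the test suite, the standalone sketch below (illustrative only, not part of the patch) parses a small TorchScript graph containing the aten::Int.Tensor pattern and runs ReplaceAtenInt over it. It assumes a Torch-TensorRT build where the declaration added in the first hunk is reachable through "core/lowering/passes/passes.h"; the graph and file layout are hypothetical:

#include <memory>
#include <string>

#include "core/lowering/passes/passes.h" // assumed location of the ReplaceAtenInt declaration
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // Computes aten::Int(size(x, 0) * 2) through 0D Tensor intermediates,
  // mirroring the pattern exercised by RemoveAtenIntSizeTensorValuesAgree
  const std::string graph_ir = R"IR(
    graph(%x : Tensor):
      %0 : int = prim::Constant[value=0]()
      %two : int = prim::Constant[value=2]()
      %s : int = aten::size(%x, %0)
      %st : Tensor = prim::NumToTensor(%s)
      %tt : Tensor = prim::NumToTensor(%two)
      %m : Tensor = aten::mul(%st, %tt)
      %i : int = aten::Int(%m)
      return (%i))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // Rewrites aten::Int and its 0D Tensor parents into scalar arithmetic;
  // the dumped graph should multiply ints directly and contain no aten::Int node
  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(g);
  g->dump();
  return 0;
}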