From 2866627c8fb43202039eeb6eebe78c32152cec41 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 28 Jul 2021 00:54:19 -0700 Subject: [PATCH] feat(aten::std|aten::masked_fill): Implement masked_fill, aten::std works for non bias corrected cases Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversion.cpp | 6 + .../converters/impl/element_wise.cpp | 24 ++ core/conversion/converters/impl/reduce.cpp | 10 +- core/conversion/converters/impl/select.cpp | 34 ++- core/conversion/evaluators/aten.cpp | 4 +- core/lowering/lowering.cpp | 2 + core/lowering/passes/BUILD | 4 +- core/lowering/passes/passes.h | 2 + core/lowering/passes/unpack_std.cpp | 30 +++ core/lowering/passes/unpack_var.cpp | 51 +++++ core/util/trt_util.h | 2 - .../conversion/converters/test_reduce.cpp | 214 +++++++++++++++++- .../conversion/converters/test_select.cpp | 1 - tests/core/lowering/BUILD | 7 +- .../core/lowering/test_unpack_reduce_ops.cpp | 198 ++++++++++++++++ tests/util/util.cpp | 3 + 16 files changed, 561 insertions(+), 31 deletions(-) create mode 100644 core/lowering/passes/unpack_std.cpp create mode 100644 core/lowering/passes/unpack_var.cpp create mode 100644 tests/core/lowering/test_unpack_reduce_ops.cpp diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 0c7d82cb46..7dab6d3397 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -87,6 +87,9 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { if (eval) { if (!eval.value().isTensor()) { LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value()); + if (eval.value().isTuple() && eval.value().toTuple()->elements().size() == 1) { + eval.value() = {eval.value().toTuple()->elements()[0]}; + } } else { LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')'); } @@ -283,6 +286,9 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, boo auto eval = EvaluateNode(ctx, bn); if (!eval.value().isTensor()) { LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value()); + if (eval.value().isTuple() && eval.value().toTuple()->elements().size() == 1) { + eval.value() = {eval.value().toTuple()->elements()[0]}; + } } else { LOG_DEBUG( ctx->logger, diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index 568e8840a4..5008e794c7 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -185,6 +185,30 @@ auto element_wise_registrations TRTORCH_UNUSED = LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) + .pattern({"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Should implement self - alpha * other + auto self = args[0].ITensorOrFreeze(ctx); + auto other = args[1].unwrapToScalar().to(); + auto alpha = args[2].unwrapToScalar().to(); + + auto rhs = other * alpha; + if (1 != rhs) { + auto rhs_tensor = tensor_to_const(ctx, torch::tensor({rhs})); + auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, rhs_tensor, util::node_info(n)); + TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n); + sub->setName(util::node_info(n).c_str()); + LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions()); + ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0)); + return true; + } 
else { + LOG_DEBUG("Nothing to be done this layer, passing through input"); + LOG_DEBUG("Output tensor shape: " << self->getDimensions()); + + ctx->AssociateValueAndTensor(n->outputs()[0], self); + return true; + } + }}) .pattern({"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar " "alpha=1) -> (Tensor(a!))", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp index 101e9aca89..534267f3b5 100644 --- a/core/conversion/converters/impl/reduce.cpp +++ b/core/conversion/converters/impl/reduce.cpp @@ -106,10 +106,10 @@ auto reduce_registrations TRTORCH_UNUSED = for (size_t d = 0; d < calculated_dims.size(); d++) { axis_mask |= 1 << calculated_dims[d]; } - LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask)); + LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask)); auto keepdim = args[2].unwrapToBool(); - LOG_DEBUG("Keep dims :" << keepdim); + LOG_DEBUG("Keep dims: " << keepdim); LOG_WARNING("Sum converter disregards dtype"); auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim); @@ -145,13 +145,13 @@ auto reduce_registrations TRTORCH_UNUSED = [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in_tensor = args[0].ITensorOrFreeze(ctx); auto dim = args[1].unwrapToInt(); - LOG_DEBUG("Dim to reduce:" << dim); // Some abuse of toDim but just for debug info + LOG_DEBUG("Dim to reduce: " << dim); // Some abuse of toDim but just for debug info uint32_t axis_mask = 1 << dim; - LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask)); + LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask)); auto keepdim = args[2].unwrapToBool(); - LOG_DEBUG("Keep dims :" << keepdim); + LOG_DEBUG("Keep dims: " << keepdim); LOG_WARNING("Prod converter disregards dtype"); auto prod_layer = diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 9b4e39c2b5..9863dd59c4 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -71,7 +71,7 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() .pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); + auto in = args[0].ITensorOrFreeze(ctx); auto maxDim = static_cast(in->getDimensions().nbDims); auto axis = args[1].unwrapToInt(); axis = axis < 0 ? 
axis + maxDim : axis; @@ -79,27 +79,26 @@ auto select_registrations TRTORCH_UNUSED = // index to access needs to be an at::Tensor at::Tensor indices = torch::tensor({ind}).to(torch::kI32); - auto weights = Weights(ctx, indices); - - // IConstantLayer to convert indices from Weights to ITensor - auto const_layer = ctx->net->addConstant(weights.shape, weights.data); - TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n); - auto const_out = const_layer->getOutput(0); + auto const_out = tensor_to_const(ctx, indices); // IGatherLayer takes in input tensor, the indices, and the axis // of input tensor to take indices from auto gather_layer = ctx->net->addGather(*in, *const_out, axis); TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto gather_out = gather_layer->getOutput(0); + auto out = gather_layer->getOutput(0); - // IShuffleLayer removes redundant dimensions - auto shuffle_layer = ctx->net->addShuffle(*gather_out); - TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setReshapeDimensions(util::squeezeDims(gather_out->getDimensions(), axis)); - shuffle_layer->setName(util::node_info(n).c_str()); - auto shuffle_out = shuffle_layer->getOutput(0); + LOG_DEBUG("Gather tensor shape: " << out->getDimensions()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out); + if (out->getDimensions().nbDims != 1) { + // IShuffleLayer removes redundant dimensions + auto shuffle_layer = ctx->net->addShuffle(*out); + TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); + shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), axis)); + shuffle_layer->setName(util::node_info(n).c_str()); + out = shuffle_layer->getOutput(0); + } + + out = ctx->AssociateValueAndTensor(n->outputs()[0], out); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); @@ -253,15 +252,14 @@ auto select_registrations TRTORCH_UNUSED = "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto self = args[0].ITensorOrFreeze(ctx); - LOG_DEBUG(args[1].unwrapToTensor()); auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL); + mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true); auto val = args[2].unwrapToScalar().to(); - LOG_DEBUG(torch::full(util::toVec(self->getDimensions()), val)); auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val)); TRTORCH_CHECK(util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false), "Self and mask tensors are not broadcastable"); - auto new_layer = ctx->net->addSelect(*mask, *self, *val_t); + auto new_layer = ctx->net->addSelect(*mask, *val_t, *self); TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill"); new_layer->setName(util::node_info(n).c_str()); diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 8e367f9779..a119925fa6 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -573,10 +573,10 @@ auto aten_registrations TRTORCH_UNUSED = auto dtype = args.at(n->input(1)).IValue(); auto device = args.at(n->input(2)).IValue(); auto tensor = createTensorFromList(*data, *dtype, *device); - LOG_DEBUG(tensor); if (tensor.dtype() == at::kByte) { - return tensor.to(at::kInt); + return tensor.to(at::kFloat); } + std::cout << 
tensor << std::endl; return tensor; }, EvalOptions().validSchemas({ diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 70439fb127..41f555b95b 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -48,6 +48,8 @@ void LowerGraph(std::shared_ptr& g) { passes::UnpackAddMM(g); // passes::UnpackBatchNorm(g); passes::UnpackLogSoftmax(g); + passes::UnpackStd(g); + passes::UnpackVar(g); passes::RemoveNOPs(g); passes::AliasOperators(g); passes::SiluToSigmoidMultipication(g); diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index d615af06bc..8a7c62747c 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -24,7 +24,9 @@ cc_library( "unpack_addmm.cpp", "unpack_batch_norm.cpp", "unpack_log_softmax.cpp", - "unpack_hardswish.cpp" + "unpack_hardswish.cpp", + "unpack_std.cpp", + "unpack_var.cpp", ], hdrs = [ "passes.h", diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index db2bcb2ebe..c7aab97cad 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -19,6 +19,8 @@ void RemoveNOPs(std::shared_ptr graph); void UnpackAddMM(std::shared_ptr& graph); void UnpackBatchNorm(std::shared_ptr& graph); void UnpackLogSoftmax(std::shared_ptr& graph); +void UnpackStd(std::shared_ptr& graph); +void UnpackVar(std::shared_ptr& graph); void AliasOperators(std::shared_ptr& graph); void SiluToSigmoidMultipication(std::shared_ptr& graph); void UnpackHardSwish(std::shared_ptr& graph); diff --git a/core/lowering/passes/unpack_std.cpp b/core/lowering/passes/unpack_std.cpp new file mode 100644 index 0000000000..4ba8d9a02d --- /dev/null +++ b/core/lowering/passes/unpack_std.cpp @@ -0,0 +1,30 @@ +#include "torch/csrc/jit/passes/subgraph_rewrite.h" + +#include "core/util/prelude.h" + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { + +void UnpackStd(std::shared_ptr& graph) { + std::string std_pattern = R"IR( + graph(%1, %dim, %unbiased, %keepdim): + %out: Tensor = aten::std(%1, %dim, %unbiased, %keepdim) + return (%out))IR"; + std::string unpacked_pattern = R"IR( + graph(%1, %dim, %unbiased, %keepdim): + %z: Tensor = aten::var(%1, %dim, %unbiased, %keepdim) + %out: Tensor = aten::sqrt(%z) + return (%out))IR"; + + torch::jit::SubgraphRewriter std_rewriter; + std_rewriter.RegisterRewritePattern(std_pattern, unpacked_pattern); + std_rewriter.runOnGraph(graph); + LOG_GRAPH("Post unpack std: " << *graph); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace trtorch diff --git a/core/lowering/passes/unpack_var.cpp b/core/lowering/passes/unpack_var.cpp new file mode 100644 index 0000000000..48a1b7232c --- /dev/null +++ b/core/lowering/passes/unpack_var.cpp @@ -0,0 +1,51 @@ +#include "torch/csrc/jit/passes/subgraph_rewrite.h" + +#include "core/util/prelude.h" + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { + +void UnpackVar(std::shared_ptr& graph) { + std::string var_pattern = R"IR( + graph(%input, %dim, %unbiased, %keepdim): + %out: Tensor = aten::var(%input, %dim, %unbiased, %keepdim) + return (%out))IR"; + std::string unpacked_pattern = R"IR( + graph(%input, %dims, %unbiased, %keepdim): + %none: None = prim::Constant() + %false: bool = prim::Constant[value=0]() + %0: int = prim::Constant[value=0]() + %1: int = prim::Constant[value=1]() + %sqrd: Tensor = aten::mul(%input, %input) + %sqrdmean: Tensor = aten::mean(%sqrd, %dims, %keepdim, %none) + %mean: Tensor = aten::mean(%input, %dims, %keepdim, 
%none) + %meansqrd: Tensor = aten::mul(%mean, %mean) + %var: Tensor = aten::sub(%sqrdmean, %meansqrd, %1) + %varout : Tensor = prim::If(%unbiased) + block0(): + %shape: int[] = aten::size(%input) + %shapet: Tensor = aten::tensor(%shape, %0, %none, %false) + %dim: int = prim::ListUnpack(%dims) + %reduceddims: Tensor = aten::select(%shapet, %0, %dim) + %numel: Tensor = aten::prod(%reduceddims, %dim, %keepdim, %none) + %mul: Tensor = aten::mul(%var, %numel) + %sub: Tensor = aten::sub(%numel, %1, %1) + %v: Tensor = aten::div(%mul, %sub) + -> (%v) + block1(): + -> (%var) + return(%varout))IR"; + + torch::jit::SubgraphRewriter var_rewriter; + var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern); + var_rewriter.runOnGraph(graph); + LOG_DEBUG("Post unpack var: " << *graph); + +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace trtorch diff --git a/core/util/trt_util.h b/core/util/trt_util.h index 8a8b399c06..a83a709d49 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -21,8 +21,6 @@ inline std::ostream& operator<<(std::ostream& os, const nvinfer1::TensorFormat& inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& dtype) { switch (dtype) { - case nvinfer1::DataType::kBOOL: - return stream << "Bool"; case nvinfer1::DataType::kFLOAT: return stream << "Float32"; case nvinfer1::DataType::kHALF: diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index 834bda49f7..f9aa01a25b 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -3,6 +3,8 @@ #include "gtest/gtest.h" #include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" +#include "core/lowering/passes/passes.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" namespace { std::string gen_basic_graph(const std::string& op) { @@ -260,4 +262,214 @@ TEST(Converters, ATenMeanDimNegIndexKeepDimsConvertsCorrectly) { return (%5))IR"; auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); test_body(graph, in); -} \ No newline at end of file +} + +TEST(Converters, UnpackVarLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %5, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = 
aten::var(%x.1, %6, %5, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarUnbiasedLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %4, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarUnbiasedKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %4, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %5, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + 
trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdUnbiasedLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %4, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + 
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 52afd7ca09..1acef5863a 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -413,7 +413,6 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { %3 : int[] = prim::ListConstruct(%1, %1, %2) %4 : int[] = prim::ListConstruct(%2, %2, %1) %5 : int[][] = prim::ListConstruct(%3, %4) - %5 : int[][][] = prim::ListConstruct(%5) %9 : Tensor = aten::tensor(%5, %1, %7, %8) # bert.py:5:11 %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11 %mask.2 : Tensor = trt::const(%mask.1) diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index 7dfdd4cbe8..e715acbe0f 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -39,6 +39,10 @@ lowering_test( name = "test_unpack_hardswish", ) +lowering_test( + name = "test_unpack_reduce_ops", +) + test_suite( name = "lowering_tests", tests = [ @@ -48,6 +52,7 @@ test_suite( ":test_remove_detach_pass", ":test_remove_dropout_pass", ":test_remove_to", - ":test_unpack_hardswish" + ":test_unpack_hardswish", + ":test_unpack_reduce_ops" ], ) diff --git a/tests/core/lowering/test_unpack_reduce_ops.cpp b/tests/core/lowering/test_unpack_reduce_ops.cpp new file mode 100644 index 0000000000..c14075da24 --- /dev/null +++ b/tests/core/lowering/test_unpack_reduce_ops.cpp @@ -0,0 +1,198 @@ +#include +#include "core/compiler.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "core/lowering/passes/passes.h" +#include "core/util/prelude.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" + + +TEST(LoweringPasses, UnpackVarLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %5, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({1, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackVarKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %5, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({1, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + 
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackVarUnbiasedLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %4, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({4, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackVarUnbiasedKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %4, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({4, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackStdLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %5, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({1, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackStdKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({1, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + 
trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackStdUnbiasedLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %4, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({4, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + std::cout << jit_pre_results[0].toTensor() << jit_post_results[0].toTensor() << std::endl; + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randn({4, 3, 3}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + trtorch::core::lowering::passes::UnpackStd(g); + trtorch::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = trtorch::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} diff --git a/tests/util/util.cpp b/tests/util/util.cpp index 89720728c2..92aa0c5926 100644 --- a/tests/util/util.cpp +++ b/tests/util/util.cpp @@ -23,6 +23,9 @@ bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) { } bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { + LOG_GRAPH(a << std::endl << b << std::endl); + std::cout << "Max Difference: " << (a - b).abs().max().item() << std::endl; + return (a - b).abs().max().item() == 0.f; }
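
Note: the new UnpackVar lowering pass above rewrites aten::var into primitive ops via the
identity Var(x) = E[x^2] - (E[x])^2, then applies Bessel's correction (scaling by N / (N - 1))
when %unbiased is true, which is why the unbiased test cases reduce over a dimension of size 4.
The sketch below is a minimal standalone libtorch check of that arithmetic, not the pass itself;
the input shape, reduce dimension, and tolerances are illustrative assumptions.

// Verifies the variance identity used by the UnpackVar rewrite pattern (sketch, assumed inputs).
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({4, 3, 3});   // assumed shape, mirrors the unbiased lowering tests
  std::vector<int64_t> dim = {0};     // reduce over dim 0, as in the test graphs

  // Biased (population) variance: E[x^2] - (E[x])^2
  auto sqrdmean = (x * x).mean(dim, /*keepdim=*/true);
  auto meansqrd = x.mean(dim, /*keepdim=*/true).pow(2);
  auto biased_var = sqrdmean - meansqrd;

  // Bessel's correction for the unbiased case: scale by N / (N - 1),
  // where N is the number of elements along the reduced dimension.
  auto n = static_cast<double>(x.size(0));
  auto unbiased_var = biased_var * n / (n - 1.0);

  std::cout << torch::allclose(biased_var, x.var(dim, /*unbiased=*/false, /*keepdim=*/true), 1e-5, 1e-5)
            << " "
            << torch::allclose(unbiased_var, x.var(dim, /*unbiased=*/true, /*keepdim=*/true), 1e-5, 1e-5)
            << std::endl;
  return 0;
}

Both comparisons should print 1; the UnpackStd pass reduces aten::std to aten::sqrt of this
aten::var expansion, so the same identity covers the std test cases as well.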