diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 56ddacd864..815b738af2 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -119,11 +119,16 @@ auto aten_registrations TRTORCH_UNUSED =
                     // Device? device=None, bool? pin_memory=None) -> (Tensor)
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto options = torch::TensorOptions()
-                                         .dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
                                          .layout(torch::kStrided)
                                          .device(torch::kCUDA);
 
+                      // dtype (input 1) is optional and arrives as None when the
+                      // caller does not pass one; only apply it when it is set
+                      if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
+                        options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
+                      }
+
                       auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
                       return out_tensor;
                     }})
     .evaluator({c10::Symbol::fromQualString("aten::slice"),
diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
index 1d235572b2..93f89ec471 100644
--- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp
+++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -36,4 +36,23 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
   auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
 
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, ZerosEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : None = prim::Constant() # :0:0
+        %3 : int[] = aten::size(%x.1) # :7:9
+        %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
 }
\ No newline at end of file
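
Note on the behavior the patch targets (not part of the diff): when scripted Python calls torch.zeros(x.shape) without a dtype, TorchScript emits aten::zeros(%size, %none, %none, %none, %none), exactly as in the test graph above, so the evaluator must treat the dtype argument as optional rather than unconditionally unwrapping it. Below is a minimal standalone sketch of the same guard, with a local maybe_dtype IValue standing in for the evaluator's args lookup (hypothetical names, CPU-only options so it runs anywhere):

    #include <iostream>
    #include <torch/torch.h>

    int main() {
      // Exercise both cases the evaluator sees: dtype passed as an integer
      // ScalarType value, and dtype omitted (a default-constructed None IValue).
      for (const c10::IValue& maybe_dtype :
           {c10::IValue(static_cast<int64_t>(at::kHalf)), c10::IValue()}) {
        auto options = torch::TensorOptions().layout(torch::kStrided);
        // The guard from the patch: only set the dtype when it is not None,
        // otherwise keep the default dtype (kFloat).
        if (!maybe_dtype.isNone()) {
          options = options.dtype(c10::ScalarType(maybe_dtype.toInt()));
        }
        auto out = torch::zeros({2, 3}, options);
        std::cout << out.scalar_type() << std::endl; // Half, then Float
      }
      return 0;
    }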