From 38744bc4b58b1d5eb27368b3696d2d778c0e118a Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 4 Aug 2021 20:13:15 -0700 Subject: [PATCH] fix(aten::tensor): Last dim doesnt always get written right Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/evaluators/eval_util.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp index 7e9411ac80..19898605b0 100644 --- a/core/conversion/evaluators/eval_util.cpp +++ b/core/conversion/evaluators/eval_util.cpp @@ -129,7 +129,7 @@ void storeLastDimension( auto n = sizes[dim]; auto seq_size = obj.size(); checkSequenceSize(n, dim, seq_size); - for (const auto i : c10::irange(n)) { + for (int64_t i = 0; i < n; i++) { *(DTYPE*)data = obj[i].to(); data += strides[dim] * elementSize; } @@ -189,17 +189,17 @@ void recursiveStore( } else if (obj.isBoolList()) { storeLastDimension(data, sizes, strides, dim, tenElementSize, seq); } else if (obj.isDoubleList()) { - if (tenElementSize == static_cast(elementSize(at::ScalarType::Double))) { + if (tenElementSize == static_cast(c10::elementSize(at::ScalarType::Double))) { storeLastDimension(data, sizes, strides, dim, tenElementSize, seq); - } else if (tenElementSize == static_cast(elementSize(at::ScalarType::Float))) { + } else if (tenElementSize == static_cast(c10::elementSize(at::ScalarType::Float))) { storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq); - } else if (tenElementSize == static_cast(elementSize(at::ScalarType::Half))) { + } else if (tenElementSize == static_cast(c10::elementSize(at::ScalarType::Half))) { storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq); } else { - TORCH_INTERNAL_ASSERT(false); + TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor"); } } else { - TORCH_INTERNAL_ASSERT(false); + TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor"); } } } @@ -231,9 +231,11 @@ at::Tensor createTensorFromList( const torch::jit::IValue& dtype, const torch::jit::IValue& device) { auto elem_type = data.type(); + /// Recurse down nested lists to find base type while (auto list_type = elem_type->cast()) { elem_type = list_type->getElementType(); } + /// Gets shape of tensor to be created auto sizes = compute_sizes(data); checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0); at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);