Skip to content

Commit

Permalink
fix(aten::tensor): Last dim doesnt always get written right
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Aug 5, 2021
1 parent 90af26e commit 38744bc
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions core/conversion/evaluators/eval_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void storeLastDimension(
auto n = sizes[dim];
auto seq_size = obj.size();
checkSequenceSize(n, dim, seq_size);
for (const auto i : c10::irange(n)) {
for (int64_t i = 0; i < n; i++) {
*(DTYPE*)data = obj[i].to<DTYPE>();
data += strides[dim] * elementSize;
}
Expand Down Expand Up @@ -189,17 +189,17 @@ void recursiveStore(
} else if (obj.isBoolList()) {
storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
} else if (obj.isDoubleList()) {
if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Double))) {
if (tenElementSize == static_cast<int>(c10::elementSize(at::ScalarType::Double))) {
storeLastDimension<double>(data, sizes, strides, dim, tenElementSize, seq);
} else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Float))) {
} else if (tenElementSize == static_cast<int>(c10::elementSize(at::ScalarType::Float))) {
storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
} else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Half))) {
} else if (tenElementSize == static_cast<int>(c10::elementSize(at::ScalarType::Half))) {
storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
} else {
TORCH_INTERNAL_ASSERT(false);
TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
}
} else {
TORCH_INTERNAL_ASSERT(false);
TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
}
}
}
Expand Down Expand Up @@ -231,9 +231,11 @@ at::Tensor createTensorFromList(
const torch::jit::IValue& dtype,
const torch::jit::IValue& device) {
auto elem_type = data.type();
/// Recurse down nested lists to find base type
while (auto list_type = elem_type->cast<c10::ListType>()) {
elem_type = list_type->getElementType();
}
/// Gets shape of tensor to be created
auto sizes = compute_sizes(data);
checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);
Expand Down

0 comments on commit 38744bc

Please sign in to comment.