Skip to content

Commit

Permalink
fix: Fix how ITensorList is detected
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 authored and bowang007 committed Apr 28, 2023
1 parent d7cb415 commit 013934a
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,31 @@ bool Var::isITensor() const {
}
}

bool Var::isITensorList() {
// Unpack the Var as a List and check if each entry is a custom class since
// ITensors are stored in CustomClassHolder
auto ival_list = ptr_.ivalue->toList();
for (int i = 0; i < ival_list.size(); i++) {
if (!ival_list.get(i).isCustomClass()) {
return false;
}
}
return true;
}

std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
TORCHTRT_CHECK(
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList");
auto ivalue_list = ptr_.ivalue->toList();
std::vector<nvinfer1::ITensor*> outputs;
for (int i = 0; i < ivalue_list.size(); i++) {
auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor();
outputs.push_back(std::move(element));
}
return outputs;
}

bool Var::isIValue() const {
if (type_ == Type::kIValue) {
return true;
Expand Down

0 comments on commit 013934a

Please sign in to comment.