Skip to content

Commit

Permalink
removing assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Feb 26, 2024
1 parent ac34cb0 commit 3814267
Showing 1 changed file with 0 additions and 14 deletions.
14 changes: 0 additions & 14 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,6 @@ def select_scatter_decomposition(
dim: int,
index: int,
) -> torch.Tensor:
# input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
# check if the dim is less than shape
if input_tensor.shape[dim] < index:
raise AssertionError("The index should not be greater than dim")

# expanding the src_tensor to have the same dimension as input_tensor
# check if the dimension of the src tensor is same as slice tensor
select_tensor = torch.select(input_tensor, dim, index)

if select_tensor.shape != src_tensor.shape:
raise AssertionError(
"The slice tensor shape should be equal to the src tensor shape"
)

unbind_tensors = torch.unbind(input_tensor, dim)
unbind_tensors_list = list(unbind_tensors)
unbind_tensors_list[index] = src_tensor
Expand Down

0 comments on commit 3814267

Please sign in to comment.