From 5431f29c6a7c4209fa704661f7f2770e38cf981e Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 24 Jun 2024 14:28:24 -0700 Subject: [PATCH] empty tensor moving to default device --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index b96d912897..a917aadda2 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: perm = [0] * len(empty_size) for permute_index, permute_element in enumerate(empty_permute): perm[permute_element] = permute_index + default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + kwargs[device] = default_device return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) @@ -233,7 +235,8 @@ def select_scatter_decomposition( def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: empty_size = args[0] empty_stride = args[1] - return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride) + default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return torch.as_strided(torch.empty(empty_size, device = default_device), empty_size, empty_stride) def get_decompositions(