From 7f8bb4f70ff9799155f00a395584d38926fa8aa3 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 24 Jun 2024 14:28:24 -0700 Subject: [PATCH] empty tensor moving to default device --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index b96d912897..833a488f60 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: perm = [0] * len(empty_size) for permute_index, permute_element in enumerate(empty_permute): perm[permute_element] = permute_index + default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + kwargs[device] = default_device return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) @@ -233,7 +235,10 @@ def select_scatter_decomposition( def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: empty_size = args[0] empty_stride = args[1] - return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride) + default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.as_strided( + torch.empty(empty_size, device=default_device), empty_size, empty_stride + ) def get_decompositions(