diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index b96d912897..a917aadda2 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: perm = [0] * len(empty_size) for permute_index, permute_element in enumerate(empty_permute): perm[permute_element] = permute_index + default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + kwargs[device] = default_device return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) @@ -233,7 +235,8 @@ def select_scatter_decomposition( def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: empty_size = args[0] empty_stride = args[1] - return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride) + default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return torch.as_strided(torch.empty(empty_size, device = default_device), empty_size, empty_stride) def get_decompositions(