diff --git a/exir/dim_order_utils.py b/exir/dim_order_utils.py
index 0aae6e92308..562244b6a48 100644
--- a/exir/dim_order_utils.py
+++ b/exir/dim_order_utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List
+from typing import List, Optional
 
 import torch
 
@@ -27,11 +27,13 @@ def _get_channels_last_dim_order(ndim: int) -> List[int]:
     raise AssertionError(f"Unsupported rank: {ndim}")
 
 
-def get_memory_format(dim_order: List[int]) -> torch.memory_format:
+def get_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format:
     """
     Given a dim_order try to map it to torch.memory_format
     """
-    if dim_order == _get_contiguous_dim_order(len(dim_order)):
+    if dim_order is None:
+        return torch.preserve_format
+    elif dim_order == _get_contiguous_dim_order(len(dim_order)):
         return torch.contiguous_format
     elif len(dim_order) == 4 and dim_order == _get_channels_last_dim_order(
         len(dim_order)
@@ -43,11 +45,15 @@ def get_memory_format(dim_order: List[int]) -> torch.memory_format:
         )
 
 
-def get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]:
+def get_dim_order(
+    memory_format: Optional[torch.memory_format], ndim: int
+) -> Optional[List[int]]:
     """
     Given a memory_format and a tensor rank, generate a dim_order
     """
-    if memory_format == torch.contiguous_format:
+    if memory_format in [None, torch.preserve_format]:
+        return None
+    elif memory_format == torch.contiguous_format:
         return _get_contiguous_dim_order(ndim)
     elif memory_format == torch.channels_last:
         return _get_channels_last_dim_order(ndim)
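
A short usage sketch of the behavior this diff adds (not part of the change itself). It assumes the module is importable as executorch.exir.dim_order_utils; the in-tree path is exir/dim_order_utils.py, so adjust the import to match your install. It shows the new None / torch.preserve_format round-trip next to the unchanged contiguous and channels-last mappings.

# Usage sketch only; the import path is an assumption about packaging.
import torch

from executorch.exir.dim_order_utils import get_dim_order, get_memory_format

# New behavior: an unspecified memory format maps to a None dim_order,
# and a None dim_order maps back to torch.preserve_format.
assert get_dim_order(None, ndim=4) is None
assert get_dim_order(torch.preserve_format, ndim=4) is None
assert get_memory_format(None) == torch.preserve_format

# Existing behavior for concrete formats is unchanged: the standard
# rank-4 dim orders for contiguous (NCHW) and channels-last (NHWC).
assert get_dim_order(torch.contiguous_format, ndim=4) == [0, 1, 2, 3]
assert get_dim_order(torch.channels_last, ndim=4) == [0, 2, 3, 1]
assert get_memory_format([0, 2, 3, 1]) == torch.channels_last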