Skip to content

Commit

Permalink
make utils support empty dim order (#2142)
Browse files Browse the repository at this point in the history
Summary:

This update makes util function support empty dim order, to make the empty dim order behave the same as empty memory format (preserve_format).
bypass-github-export-checks

Differential Revision: D54236386
  • Loading branch information
Gasoonjia authored and facebook-github-bot committed Mar 18, 2024
1 parent 45df800 commit 26d87bd
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions exir/dim_order_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
from typing import List, Optional

import torch

Expand All @@ -27,11 +27,13 @@ def _get_channels_last_dim_order(ndim: int) -> List[int]:
raise AssertionError(f"Unsupported rank: {ndim}")


def get_memory_format(dim_order: List[int]) -> torch.memory_format:
def get_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format:
"""
Given a dim_order try to map it to torch.memory_format
"""
if dim_order == _get_contiguous_dim_order(len(dim_order)):
if dim_order is None:
return torch.preserve_format
elif dim_order == _get_contiguous_dim_order(len(dim_order)):
return torch.contiguous_format
elif len(dim_order) == 4 and dim_order == _get_channels_last_dim_order(
len(dim_order)
Expand All @@ -43,11 +45,15 @@ def get_memory_format(dim_order: List[int]) -> torch.memory_format:
)


def get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]:
def get_dim_order(
memory_format: Optional[torch.memory_format], ndim: int
) -> Optional[List[int]]:
"""
Given a memory_format and a tensor rank, generate a dim_order
"""
if memory_format == torch.contiguous_format:
if memory_format in [None, torch.preserve_format]:
return None
elif memory_format == torch.contiguous_format:
return _get_contiguous_dim_order(ndim)
elif memory_format == torch.channels_last:
return _get_channels_last_dim_order(ndim)
Expand Down

0 comments on commit 26d87bd

Please sign in to comment.