Skip to content

Commit

Permalink
[Release/2.1][JIT] Fix typed enum handling in 3.11 (#109807)
Browse files Browse the repository at this point in the history
In Python-3.11+ typed enums (such as `enum.IntEnum`) retain `__new__`,`__str__` and so on method of the base class via `__init__subclass__()` method (see https://docs.python.org/3/whatsnew/3.11.html#enum ), i.e. following code
```python
import sys
import inspect
from enum import Enum

class IntColor(int, Enum):
    RED = 1
    GREEN = 2

class Color(Enum):
    RED = 1
    GREEN = 2

def get_methods(cls):
    def predicate(m):
        if not inspect.isfunction(m) and not inspect.ismethod(m):
            return False
        return m.__name__ in cls.__dict__
    return inspect.getmembers(cls, predicate=predicate)

if __name__ == "__main__":
    print(sys.version)
    print(f"IntColor methods {get_methods(IntColor)}")
    print(f"Color methods {get_methods(Color)}")
```

Returns empty list for both cases for older Python, but on Python-3.11+ it returns list contains of enum constructors and others:
```shell
% conda run -n py310 python bar.py
3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:41:52) [Clang 15.0.7 ]
IntColor methods []
Color methods []
% conda run -n py311 python bar.py
3.11.0 | packaged by conda-forge | (main, Oct 25 2022, 06:21:25) [Clang 14.0.4 ]
IntColor methods [('__format__', <function Enum.__format__ at 0x105006ac0>), ('__new__', <function Enum.__new__ at 0x105006660>), ('__repr__', <function Enum.__repr__ at 0x1050068e0>)]
Color methods []
```

This change allows typed enums to be scriptable on 3.11, by explicitly marking several `enum.Enum` method to be dropped by jit script and adds test that typed enums are jit-scriptable.

Fixes #108933

Cherry-pick of #109717 into release/2.1 branch.
Approved by: https://github.com/atalman, https://github.com/davidberard98

(cherry picked from commit 55685d5)
  • Loading branch information
malfet authored Sep 21, 2023
1 parent c464075 commit 9287a0c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/jit/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,13 @@ class Color(Enum):
GREEN = 2

torch.jit.script(Color)

# Regression test for https://github.com/pytorch/pytorch/issues/108933
def test_typed_enum(self):
class Color(int, Enum):
RED = 1
GREEN = 2

@torch.jit.script
def is_red(x: Color) -> bool:
return x == Color.RED
12 changes: 11 additions & 1 deletion torch/_jit_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def createResolutionCallbackForClassMethods(cls):
# Skip built-ins, as they do not have global scope nor type hints
# Needed to support `enum.Enum` derived classes in Python-3.11
# That adds `_new_member_` property which is an alias to `__new__`
fns = [fn for fn in fns if not inspect.isbuiltin(fn)]
fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
captures = {}

for fn in fns:
Expand Down Expand Up @@ -1491,3 +1491,13 @@ def _extract_tensors(obj):
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
extractor.dump(obj)
return tensors


# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
# that were previously dropped. To preserve the behavior, explicitly drop them there

if sys.version_info > (3, 10):
_drop(enum.Enum.__new__)
_drop(enum.Enum.__format__)
_drop(enum.Enum.__repr__)
_drop(enum.Enum.__str__)

0 comments on commit 9287a0c

Please sign in to comment.