From a80a95e9eb66573a3666161eb2348fa98a1d4df9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 7 Jan 2025 15:58:16 +0000 Subject: [PATCH] [BugFix] Fix unitary ops for tensorclass ghstack-source-id: 2d117645769890b72f5856f68acbe1b48015cfbb Pull Request resolved: https://github.com/pytorch/tensordict/pull/1164 --- tensordict/tensorclass.py | 13 ++++++++++--- test/test_tensorclass.py | 41 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 141e47aaf..957c1ddd4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -163,14 +163,19 @@ def __subclasscheck__(self, subclass): ] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' +_FALLBACK_METHOD_FROM_TD_FORCE = [ + "__ge__", + "__gt__", + "__le__", + "__lt__", + "__ror__", +] _FALLBACK_METHOD_FROM_TD = [ "__abs__", "__add__", "__and__", "__bool__", "__eq__", - "__ge__", - "__gt__", "__iadd__", "__imul__", "__invert__", @@ -185,7 +190,6 @@ def __subclasscheck__(self, subclass): "__radd__", "__rand__", "__rmul__", - "__ror__", "__rpow__", "__rsub__", "__rtruediv__", @@ -240,6 +244,7 @@ def __subclasscheck__(self, subclass): "auto_batch_size_", "auto_device_", "bitwise_and", + "bool", "ceil", "ceil_", "chunk", @@ -814,6 +819,8 @@ def __torch_function__( for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name)) + for method_name in _FALLBACK_METHOD_FROM_TD_FORCE: + setattr(cls, method_name, _wrap_td_method(method_name)) for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index d2b808111..e2ccc9989 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -2588,6 +2588,47 @@ class X: assert (x.mul(2) == (x * 2)).all() assert (x.div(2) == (x / 2)).all() + def test_logic_and_right_ops(self): + @tensorclass + class MyClass: + x: str + + c = MyClass(torch.randn(10)) + _ = c < 0 + _ = c > 0 + _ = c <= 0 + _ = c >= 0 + _ = c != 0 + + _ = c.bool() ^ True + _ = True ^ c.bool() + + _ = c.bool() | False + _ = False | c.bool() + + _ = c.bool() & False + _ = False & c.bool() + + _ = abs(c) + + _ = c + 1 + _ = 1 + c + c += 1 + + _ = c * 1 + _ = 1 * c + + _ = c - 1 + _ = 1 - c + c -= 1 + + _ = c / 1 + _ = 1 / c + + _ = c**1 + # not implemented + # 1 ** c + class TestSubClassing: def test_subclassing(self):