diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d2b249cb5..2c0397422 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -196,6 +196,9 @@ def __torch_function__( cls.load_state_dict = _load_state_dict cls._memmap_ = _memmap_ + cls.__enter__ = __enter__ + cls.__exit__ = __exit__ + # Memmap cls.memmap_like = TensorDictBase.memmap_like cls.memmap_ = TensorDictBase.memmap_ @@ -421,6 +424,14 @@ def _load_memmap(cls, prefix: Path, metadata: dict): return cls._from_tensordict(td, non_tensordict) +def __enter__(self, *args, **kwargs): + return self._tensordict.__enter__(*args, **kwargs) + + +def __exit__(self, *args, **kwargs): + return self._tensordict.__exit__(*args, **kwargs) + + def _getstate(self) -> dict[str, Any]: """Returns a state dict which consists of tensor and non_tensor dicts for serialization. diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index b430674b8..d1195c40e 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1785,6 +1785,22 @@ def test_to(self): assert isinstance(td.get("c")[0], self.TensorClass) +def test_decorator(): + @tensorclass + class MyClass: + X: torch.Tensor + y: Any + + obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[]) + assert not obj.is_locked + with obj.lock_(): + assert obj.is_locked + with obj.unlock_(): + assert not obj.is_locked + assert obj.is_locked + assert not obj.is_locked + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)