diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index ab8b99e2e..f14d38b67 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -533,7 +533,7 @@ def __torch_function__( if not hasattr(cls, "names"): cls.require = property(_names, _names_setter) if not _is_non_tensor and not hasattr(cls, "data"): - cls.data = property(_data) + cls.data = property(_data, _data_setter) if not hasattr(cls, "grad"): cls.grad = property(_grad) if not hasattr(cls, "to_dict"): @@ -1719,9 +1719,21 @@ def _names(self) -> torch.Size: def _data(self): + # We allow data to be a field of the class too + if "data" in self.__dataclass_fields__: + data = self._tensordict.get("data", None) + if data is None: + data = self._non_tensordict.get("data") + return data return self._from_tensordict(self._tensordict.data, self._non_tensordict) +def _data_setter(self, new_data): + if "data" in self.__dataclass_fields__: + return self.set("data", new_data) + raise AttributeError("property 'data' is read-only.") + + def _grad(self): grad = self._tensordict._grad if grad is None: