Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 24, 2024
1 parent 003be12 commit 43d0a36
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def __torch_function__(
cls.load_state_dict = _load_state_dict
cls._memmap_ = _memmap_

cls.__enter__ = __enter__
cls.__exit__ = __exit__

# Memmap
cls.memmap_like = TensorDictBase.memmap_like
cls.memmap_ = TensorDictBase.memmap_
Expand Down Expand Up @@ -421,6 +424,14 @@ def _load_memmap(cls, prefix: Path, metadata: dict):
return cls._from_tensordict(td, non_tensordict)


def __enter__(self, *args, **kwargs):
return self._tensordict.__enter__(*args, **kwargs)


def __exit__(self, *args, **kwargs):
return self._tensordict.__exit__(*args, **kwargs)


def _getstate(self) -> dict[str, Any]:
"""Returns a state dict which consists of tensor and non_tensor dicts for serialization.
Expand Down
16 changes: 16 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,22 @@ def test_to(self):
assert isinstance(td.get("c")[0], self.TensorClass)


def test_decorator():
@tensorclass
class MyClass:
X: torch.Tensor
y: Any

obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[])
assert not obj.is_locked
with obj.lock_():
assert obj.is_locked
with obj.unlock_():
assert not obj.is_locked
assert obj.is_locked
assert not obj.is_locked


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 comments on commit 43d0a36

Please sign in to comment.