Skip to content

Commit

Permalink
[Feature] NonTensorStack.from_list
Browse files Browse the repository at this point in the history
ghstack-source-id: 4839e805482832b55fe57af35c06ed4a29e6d026
Pull Request resolved: #1107
  • Loading branch information
vmoens committed Nov 24, 2024
1 parent 7dda788 commit a99906f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __subclasscheck__(self, subclass):
"any",
"apply",
"apply_",
"as_tensor",
"asin",
"asin_",
"atan",
Expand Down Expand Up @@ -3114,6 +3115,18 @@ def maybe_to_stack(self):
stack_dim=self.stack_dim,
)

@classmethod
def from_list(cls, non_tensors: List[Any]):
# Use local function because refers to cls
def _maybe_from_list(nontensor):
if isinstance(nontensor, list):
return cls.from_list(nontensor)
if is_non_tensor(nontensor):
return nontensor
return NonTensorData(nontensor)

return cls(*[_maybe_from_list(nontensor) for nontensor in non_tensors])

@classmethod
def from_nontensordata(cls, non_tensor: NonTensorData):
data = non_tensor.data
Expand Down
13 changes: 13 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10659,6 +10659,19 @@ def test_comparison(self, non_tensor_data):
("nested", "bool")
)

def test_from_list(self):
nd = NonTensorStack.from_list(
[[True, "b", torch.randn(())], ["another", 0, NonTensorData("final")]]
)
assert isinstance(nd, NonTensorStack)
assert nd.shape == (2, 3)
assert nd[0, 0].data
assert nd[0, 1].data == "b"
assert isinstance(nd[0, 2].data, torch.Tensor)
assert nd[1, 0].data == "another"
assert nd[1, 1].data == 0
assert nd[1, 2].data == "final"

def test_non_tensor_call(self):
td0 = TensorDict({"a": 0, "b": 0})
td1 = TensorDict({"a": 1, "b": 1})
Expand Down

0 comments on commit a99906f

Please sign in to comment.