[BugFix] Fix get for nestedkeys with default in tensorclass #1211

Merged · 2 commits · Feb 6, 2025
41 changes: 39 additions & 2 deletions tensordict/base.py
@@ -6384,6 +6384,11 @@ def _default_get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleT
_KEY_ERROR.format(key, type(self).__name__, sorted(self.keys()))
)

@overload
def get(self, key): ...
@overload
def get(self, key, default): ...

def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
"""Gets the value stored with the input key.

@@ -6439,8 +6444,17 @@ def _get_tuple_maybe_non_tensor(self, key, default):
return result.data
return result

@overload
def get_at(self, key, index): ...

@overload
def get_at(self, key, index, default): ...

def get_at(
self, key: NestedKey, index: IndexType, default: CompatibleType = NO_DEFAULT
self,
key: NestedKey,
*args,
**kwargs,
) -> CompatibleType:
"""Get the value of a tensordict from the key `key` at the index `idx`.

@@ -6463,7 +6477,30 @@ def get_at(
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))
# must be a tuple

try:
if len(args):
index = args[0]
args = args[1:]
else:
index = kwargs.pop("index")
except KeyError:
raise TypeError("index argument missing from get_at")

# Find what the default is
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT

return self._get_at_tuple(key, index, default)

def _get_at_str(self, key, idx, default):
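The patched `get_at` now accepts the index positionally or as the `index` keyword, and treats a single extra positional argument (or the `default` keyword) as the fallback value. A minimal standalone sketch of that argument resolution, with local stand-ins for the library's `NO_DEFAULT` sentinel and `_GET_DEFAULTS_TO_NONE` switch, might look like this (the real logic lives in `tensordict/base.py`):

```python
# Sketch only: replicates the parsing added to get_at, not the library API.
_NO_DEFAULT = object()          # stand-in for tensordict's NO_DEFAULT sentinel
_GET_DEFAULTS_TO_NONE = True    # stand-in for the library's global switch


def resolve_get_at_args(*args, **kwargs):
    """Return (index, default) the way the patched get_at resolves them."""
    # The index may come positionally or as the `index` keyword.
    try:
        if len(args):
            index, args = args[0], args[1:]
        else:
            index = kwargs.pop("index")
    except KeyError:
        raise TypeError("index argument missing from get_at")

    # At most one further (keyword) argument is allowed: the default.
    if args:
        default = args[0]
        if len(args) > 1 or kwargs:
            raise TypeError("only one (keyword) argument is allowed.")
    elif kwargs:
        default = kwargs.pop("default")
        if args or kwargs:
            raise TypeError("only one (keyword) argument is allowed.")
    elif _GET_DEFAULTS_TO_NONE:
        default = None
    else:
        default = _NO_DEFAULT
    return index, default


print(resolve_get_at_args(0))                         # (0, None)
print(resolve_get_at_args(0, "else"))                 # (0, 'else')
print(resolve_get_at_args(index=0, default="else"))   # (0, 'else')
```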
74 changes: 58 additions & 16 deletions tensordict/tensorclass.py
@@ -41,11 +41,13 @@
from tensordict._torch_func import TD_HANDLED_FUNCTIONS
from tensordict.base import (
_ACCEPTED_CLASSES,
_GET_DEFAULTS_TO_NONE,
_is_tensor_collection,
_register_tensor_class,
CompatibleType,
)
from tensordict.utils import ( # @manual=//pytorch/tensordict:_C
_GENERIC_NESTED_ERR,
_is_dataclass as is_dataclass,
_is_json_serializable,
_is_tensorclass,
@@ -2238,7 +2240,7 @@ def _set_at_(
return self._tensordict.set_at_(key, value, idx, non_blocking=non_blocking)


def _get(self, key: NestedKey, default: Any = NO_DEFAULT):
def _get(self, key: NestedKey, *args, **kwargs):
"""Gets the value stored with the input key.

Args:
@@ -2250,25 +2252,65 @@ def _get(self, key: NestedKey, default: Any = NO_DEFAULT):
value stored with the input key

"""
if isinstance(key, str):
key = (key,)
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))

# Find what the default is
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT

if isinstance(key, tuple):
try:
if len(key) > 1:
return getattr(self, key[0]).get(key[1:])
return getattr(self, key[0])
except AttributeError:
if default is NO_DEFAULT:
raise
return default
raise ValueError(f"Supported type for key are str and tuple, got {type(key)}")
try:
if len(key) > 1:
return getattr(self, key[0]).get(key[1:], default=default)
return getattr(self, key[0])
except (AttributeError, KeyError):
if default is NO_DEFAULT:
raise
return default


def _get_at(self, key: NestedKey, *args, **kwargs):
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))

try:
if len(args):
index = args[0]
args = args[1:]
else:
index = kwargs.pop("index")
except KeyError:
raise TypeError("index argument missing from get_at")

# Find what the default is
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT

def _get_at(self, key: NestedKey, idx, default: Any = NO_DEFAULT):
try:
return self.get(key, NO_DEFAULT)[idx]
except AttributeError:
return self.get(key, NO_DEFAULT)[index]
except (AttributeError, KeyError):
if default is NO_DEFAULT:
raise
return default
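In user-facing terms, the tensorclass change means that `get` (and `get_at`) with a nested key now honours the provided default instead of raising when an intermediate or leaf entry is missing. A short usage sketch, mirroring the new `test_get_default` test below (the `None` fallback assumes the default-to-`None` behaviour gated by `_GET_DEFAULTS_TO_NONE`):

```python
import torch
from tensordict import TensorDict, tensorclass


@tensorclass
class Data:
    td: TensorDict
    a: torch.Tensor


data = Data(td=TensorDict(), a=torch.zeros(()))

# Flat keys already worked; nested keys are what this PR fixes.
assert data.get("b", "else") == "else"                 # missing flat key -> default
assert data.get(("td", "missing"), "else") == "else"   # missing nested key -> default
assert data.get(("td", "missing")) is None             # no default -> None

# get_at follows the same convention, indexing into the retrieved value.
batched = data.expand(10)
assert batched.get_at(("td", "missing"), 0, "else") == "else"
```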
13 changes: 12 additions & 1 deletion tensordict/tensorclass.pyi
@@ -595,9 +595,20 @@ class TensorClass:
def set_(
self, key: NestedKey, item: CompatibleType, *, non_blocking: bool = False
) -> T: ...
@overload
def get(self, key): ...
@overload
def get(self, key, default): ...
def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType: ...
@overload
def get_at(self, key, index): ...
@overload
def get_at(self, key, index, default): ...
def get_at(
self, key: NestedKey, index: IndexType, default: CompatibleType = ...
self,
key: NestedKey,
*args,
**kwargs,
) -> CompatibleType: ...
def get_item_shape(self, key: NestedKey): ...
def update(
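The stub mirrors the runtime change: two `@overload` declarations advertise the `(key)` and `(key, default)` call shapes to static type checkers, while the implementation itself takes `*args, **kwargs`. A generic, self-contained illustration of this pattern (not tensordict's actual API) could be:

```python
from typing import Any, overload


class Example:
    @overload
    def get(self, key: str) -> Any: ...
    @overload
    def get(self, key: str, default: Any) -> Any: ...
    def get(self, key: str, *args: Any, **kwargs: Any) -> Any:
        # Only this definition runs at runtime; the overloads above exist
        # purely so type checkers accept both call shapes.
        default = args[0] if args else kwargs.get("default")
        return getattr(self, key, default)
```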
118 changes: 72 additions & 46 deletions test/test_tensorclass.py
@@ -49,6 +49,7 @@
TensorDictBase,
)
from tensordict._lazy import _PermutedTensorDict, _ViewedTensorDict
from tensordict.base import _GENERIC_NESTED_ERR
from torch import Tensor

# Capture all warnings
@@ -233,6 +234,77 @@ class MyDataClass:


class TestTensorClass:
def test_get_default(self):
@tensorclass
class Data:
td: TensorDict
a: torch.Tensor

data = Data(td=TensorDict(), a=torch.zeros(()))
assert data.get("a") is not None
assert data.get("b") is None
assert data.get("b", "else") == "else"

with pytest.raises(KeyError, match=_GENERIC_NESTED_ERR.format(())):
data.get(("td", str)) # something unexpected!

assert data.get(("td", "missing"), "else") == "else"
assert data.get(("td", "missing")) is None

data = data.expand(10)
assert data.get_at("a", 0) is not None
assert data.get_at("b", 0) is None
assert data.get_at("b", 0, "else") == "else"

assert data.get_at(("td", "missing"), 0, "else") == "else"
assert data.get_at(("td", "missing"), 0) is None

def test_decorator(self):
@tensorclass
class MyClass:
X: torch.Tensor
y: Any

obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[])
assert not obj.is_locked
with obj.lock_():
assert obj.is_locked
with obj.unlock_():
assert not obj.is_locked
assert obj.is_locked
assert not obj.is_locked

def test_to_dict(self):
@tensorclass
class TestClass:
my_tensor: torch.Tensor
my_str: str

test_class = TestClass(
my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3]
)

assert (
test_class
== TestClass.from_dict(test_class.to_dict(), auto_batch_size=True)
).all()

# Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such
# test_class2 = TestClass(
# my_tensor=torch.tensor([1, 2, 3]), my_str="goodbye", batch_size=[3]
# )
#
# assert not (test_class == TestClass.from_dict(test_class2.to_dict())).all()

test_class3 = TestClass(
my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3]
)

assert not (
test_class
== TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True)
).all()

def test_all_any(self):
@tensorclass
class MyClass1:
@@ -2229,52 +2301,6 @@ def test_to(self):
assert td_double.device == torch.device("cpu")


def test_decorator():
@tensorclass
class MyClass:
X: torch.Tensor
y: Any

obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[])
assert not obj.is_locked
with obj.lock_():
assert obj.is_locked
with obj.unlock_():
assert not obj.is_locked
assert obj.is_locked
assert not obj.is_locked


def test_to_dict():
@tensorclass
class TestClass:
my_tensor: torch.Tensor
my_str: str

test_class = TestClass(
my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3]
)

assert (
test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True)
).all()

# Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such
# test_class2 = TestClass(
# my_tensor=torch.tensor([1, 2, 3]), my_str="goodbye", batch_size=[3]
# )
#
# assert not (test_class == TestClass.from_dict(test_class2.to_dict())).all()

test_class3 = TestClass(
my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3]
)

assert not (
test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True)
).all()


@tensorclass(autocast=True)
class AutoCast:
tensor: torch.Tensor