Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Performance] Faster tensorclass set #880

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 50 additions & 24 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,22 +1347,26 @@ def _set(
f"Cannot set the attribute '{key}', expected attributes are {expected_keys}."
)

self_is_non_tensor = self._is_non_tensor
value_type = type(value)

def set_tensor(
key=key,
value=value,
inplace=inplace,
non_blocking=non_blocking,
non_tensor=False,
):
if self_is_non_tensor:
while is_non_tensor(value):
value = value.data
self._non_tensordict[key] = value
return self
if non_tensor:
value = NonTensorData(value)
# Avoiding key clash, honoring the user input to assign tensor type data to the key
if key in self._non_tensordict.keys():
if inplace:
raise RuntimeError(
f"Cannot update an existing entry of type {type(self._non_tensordict.get(key))} with a value of type {type(value)}."
)
if key in self._non_tensordict:
del self._non_tensordict[key]
# Avoiding key clash, honoring the user input to assign tensor type data to the key
self._tensordict.set(key, value, inplace=inplace, non_blocking=non_blocking)
return self

Expand Down Expand Up @@ -1390,7 +1394,7 @@ def _is_castable(datatype):
target_cls, tuple(tensordict_lib.base._ACCEPTED_CLASSES)
):
try:
if not issubclass(type(value), target_cls):
if not issubclass(value_type, target_cls):
if issubclass(target_cls, torch.Tensor):
# first convert to tensor to make sure that the dtype is preserved
value = torch.as_tensor(value)
Expand All @@ -1405,31 +1409,33 @@ def _is_castable(datatype):
elif value is not None and target_cls is not _AnyType:
cast_val = _cast_funcs[target_cls](value)
return set_tensor(value=cast_val, non_tensor=True)
elif target_cls is _AnyType and _is_castable(type(value)):
elif target_cls is _AnyType and _is_castable(value_type):
return set_tensor()
elif isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)):
non_tensor = not (
isinstance(value, _ACCEPTED_CLASSES)
or _is_tensor_collection(value_type)
)
elif issubclass(value_type, torch.Tensor) or _is_tensor_collection(value_type):
return set_tensor()
else:
non_tensor = True

if self._is_non_tensor or value is None:
if self_is_non_tensor or value is None:
# Avoiding key clash, honoring the user input to assign non-tensor data to the key
if key in self._tensordict.keys():
if not self_is_non_tensor and key in self._tensordict.keys():
if inplace:
raise RuntimeError(
f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}."
f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {value_type}."
)
self._tensordict.del_(key)
self._non_tensordict[key] = value
else:
if key in self._tensordict.keys():
if inplace:
if inplace:
if key in self._tensordict.keys():
raise RuntimeError(
f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}."
f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {value_type}."
)
non_tensor = not (
isinstance(value, _ACCEPTED_CLASSES)
or _is_tensor_collection(type(value))
)
set_tensor(value=value, non_tensor=non_tensor)
return set_tensor(value=value, non_tensor=non_tensor)
return self

if isinstance(key, tuple) and len(key):
Expand Down Expand Up @@ -2192,9 +2198,14 @@ def __repr__(self):
@functools.wraps(_eq)
def __eq__(self, other):
    """Equality for NonTensorData wrappers.

    If ``other`` is also a ``NonTensorData``, compare the wrapped payloads
    once and reuse the result (avoids computing ``self.data == other.data``
    twice, which is the performance point of this change). The result is
    normalized to a tensor:

    - already a ``torch.Tensor`` -> returned as-is;
    - a ``np.ndarray`` (e.g. payloads are numpy arrays) -> converted with
      ``torch.as_tensor`` on ``self.device``;
    - any other value -> coerced to ``bool`` and broadcast to
      ``self.batch_size`` via ``torch.full``.

    Otherwise defer to the previously-registered ``old_eq`` implementation.
    """
    if isinstance(other, NonTensorData):
        # Compute the payload comparison exactly once.
        eqval = self.data == other.data
        if isinstance(eqval, torch.Tensor):
            return eqval
        if isinstance(eqval, np.ndarray):
            return torch.as_tensor(eqval, device=self.device)
        # Scalar-like result: broadcast a single boolean over the batch.
        return torch.full(
            self.batch_size,
            bool(eqval),
            device=self.device,
        )
    # NOTE(review): fallback name is `old_eq` here but `_ne`/`_xor`/`_or`
    # in the sibling methods — presumably the pre-patch __eq__; defined
    # elsewhere in this file.
    return old_eq(self, other)
Expand All @@ -2206,9 +2217,14 @@ def __eq__(self, other):
@functools.wraps(_ne)
def __ne__(self, other):
    """Inequality for NonTensorData wrappers.

    Mirrors ``__eq__``: when ``other`` is a ``NonTensorData``, evaluate
    ``self.data != other.data`` a single time and normalize the result to a
    tensor — returned directly if it is already a ``torch.Tensor``,
    converted via ``torch.as_tensor`` if it is a ``np.ndarray``, otherwise
    coerced to ``bool`` and expanded to ``self.batch_size`` with
    ``torch.full``. Non-``NonTensorData`` operands defer to ``_ne``.
    """
    if isinstance(other, NonTensorData):
        # Single payload comparison, reused below.
        neqval = self.data != other.data
        if isinstance(neqval, torch.Tensor):
            return neqval
        if isinstance(neqval, np.ndarray):
            return torch.as_tensor(neqval, device=self.device)
        # Scalar-like result: broadcast over the batch dimensions.
        return torch.full(
            self.batch_size,
            bool(neqval),
            device=self.device,
        )
    return _ne(self, other)
Expand All @@ -2220,9 +2236,14 @@ def __ne__(self, other):
@functools.wraps(_xor)
def __xor__(self, other):
    """Exclusive-or for NonTensorData wrappers.

    When ``other`` is a ``NonTensorData``, compute ``self.data ^ other.data``
    once and normalize the result to a tensor: tensors pass through,
    ``np.ndarray`` results are converted with ``torch.as_tensor`` on
    ``self.device``, and anything else is coerced to ``bool`` and broadcast
    to ``self.batch_size`` via ``torch.full``. Other operand types defer to
    ``_xor``.
    """
    if isinstance(other, NonTensorData):
        # Evaluate the payload xor exactly once.
        xorval = self.data ^ other.data
        if isinstance(xorval, torch.Tensor):
            return xorval
        if isinstance(xorval, np.ndarray):
            return torch.as_tensor(xorval, device=self.device)
        # Scalar-like result: broadcast a single boolean over the batch.
        return torch.full(
            self.batch_size,
            bool(xorval),
            device=self.device,
        )
    return _xor(self, other)
Expand All @@ -2234,9 +2255,14 @@ def __xor__(self, other):
@functools.wraps(_or)
def __or__(self, other):
    """Bitwise/logical OR for NonTensorData wrappers.

    When ``other`` is a ``NonTensorData``, compute ``self.data | other.data``
    once and normalize the result to a tensor: tensors pass through,
    ``np.ndarray`` results are converted with ``torch.as_tensor`` on
    ``self.device``, and anything else is coerced to ``bool`` and broadcast
    to ``self.batch_size`` via ``torch.full``. Other operand types defer to
    ``_or``.
    """
    if isinstance(other, NonTensorData):
        orval = self.data | other.data  # yuppie!
        if isinstance(orval, torch.Tensor):
            return orval
        if isinstance(orval, np.ndarray):
            return torch.as_tensor(orval, device=self.device)
        # Scalar-like result: broadcast a single boolean over the batch.
        return torch.full(
            self.batch_size,
            bool(orval),
            device=self.device,
        )
    return _or(self, other)
Expand Down
Loading