From 259e906759a6c2dbcb4ef3a32d90f7bd2113b5ea Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 12 Jul 2024 09:34:16 +0100
Subject: [PATCH 1/2] init

---
 tensordict/tensorclass.py | 41 +++++++++++++++++++++++------------------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index c51701bf3..f34775306 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -1347,6 +1347,9 @@ def _set(
                 f"Cannot set the attribute '{key}', expected attributes are {expected_keys}."
             )
 
+        self_is_non_tensor = self._is_non_tensor
+        value_type = type(value)
+
         def set_tensor(
             key=key,
             value=value,
@@ -1354,15 +1357,13 @@ def set_tensor(
             non_blocking=non_blocking,
             non_tensor=False,
         ):
+            if self_is_non_tensor:
+                raise RuntimeError("set in NontensorData should not end up here.")
             if non_tensor:
                 value = NonTensorData(value)
-            # Avoiding key clash, honoring the user input to assign tensor type data to the key
-            if key in self._non_tensordict.keys():
-                if inplace:
-                    raise RuntimeError(
-                        f"Cannot update an existing entry of type {type(self._non_tensordict.get(key))} with a value of type {type(value)}."
-                    )
+            if key in self._non_tensordict:
                 del self._non_tensordict[key]
+            # Avoiding key clash, honoring the user input to assign tensor type data to the key
             self._tensordict.set(key, value, inplace=inplace, non_blocking=non_blocking)
             return self
 
@@ -1390,7 +1391,7 @@ def _is_castable(datatype):
                 target_cls, tuple(tensordict_lib.base._ACCEPTED_CLASSES)
             ):
                 try:
-                    if not issubclass(type(value), target_cls):
+                    if not issubclass(value_type, target_cls):
                         if issubclass(target_cls, torch.Tensor):
                             # first convert to tensor to make sure that the dtype is preserved
                             value = torch.as_tensor(value)
@@ -1405,31 +1406,33 @@ def _is_castable(datatype):
             elif value is not None and target_cls is not _AnyType:
                 cast_val = _cast_funcs[target_cls](value)
                 return set_tensor(value=cast_val, non_tensor=True)
-            elif target_cls is _AnyType and _is_castable(type(value)):
+            elif target_cls is _AnyType and _is_castable(value_type):
                 return set_tensor()
-        elif isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)):
+            non_tensor = not (
+                isinstance(value, _ACCEPTED_CLASSES)
+                or _is_tensor_collection(value_type)
+            )
+        elif issubclass(value_type, torch.Tensor) or _is_tensor_collection(value_type):
             return set_tensor()
+        else:
+            non_tensor = True
 
-        if self._is_non_tensor or value is None:
+        if self_is_non_tensor or value is None:
             # Avoiding key clash, honoring the user input to assign non-tensor data to the key
             if key in self._tensordict.keys():
                 if inplace:
                     raise RuntimeError(
-                        f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}."
+                        f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {value_type}."
                     )
                 self._tensordict.del_(key)
             self._non_tensordict[key] = value
         else:
-            if key in self._tensordict.keys():
-                if inplace:
+            if inplace:
+                if key in self._tensordict.keys():
                     raise RuntimeError(
-                        f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}."
+                        f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {value_type}."
                     )
-            non_tensor = not (
-                isinstance(value, _ACCEPTED_CLASSES)
-                or _is_tensor_collection(type(value))
-            )
-            set_tensor(value=value, non_tensor=non_tensor)
+            return set_tensor(value=value, non_tensor=non_tensor)
         return self
 
     if isinstance(key, tuple) and len(key):

From e91e3155b6eec91acc07a4f7aa066ad69d9f0882 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 12 Jul 2024 10:06:09 +0100
Subject: [PATCH 2/2] amend

---
 tensordict/tensorclass.py | 35 +++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index f34775306..1702e0325 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -1358,7 +1358,10 @@ def set_tensor(
             non_tensor=False,
         ):
             if self_is_non_tensor:
-                raise RuntimeError("set in NontensorData should not end up here.")
+                while is_non_tensor(value):
+                    value = value.data
+                self._non_tensordict[key] = value
+                return self
             if non_tensor:
                 value = NonTensorData(value)
             if key in self._non_tensordict:
@@ -1419,7 +1422,7 @@ def _is_castable(datatype):
 
         if self_is_non_tensor or value is None:
             # Avoiding key clash, honoring the user input to assign non-tensor data to the key
-            if key in self._tensordict.keys():
+            if not self_is_non_tensor and key in self._tensordict.keys():
                 if inplace:
                     raise RuntimeError(
                         f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {value_type}."
@@ -2195,9 +2198,14 @@ def __repr__(self):
     @functools.wraps(_eq)
     def __eq__(self, other):
         if isinstance(other, NonTensorData):
+            eqval = self.data == other.data
+            if isinstance(eqval, torch.Tensor):
+                return eqval
+            if isinstance(eqval, np.ndarray):
+                return torch.as_tensor(eqval, device=self.device)
             return torch.full(
                 self.batch_size,
-                bool(self.data == other.data),
+                bool(eqval),
                 device=self.device,
             )
         return old_eq(self, other)
@@ -2209,9 +2217,14 @@ def __eq__(self, other):
     @functools.wraps(_ne)
     def __ne__(self, other):
         if isinstance(other, NonTensorData):
+            neqval = self.data != other.data
+            if isinstance(neqval, torch.Tensor):
+                return neqval
+            if isinstance(neqval, np.ndarray):
+                return torch.as_tensor(neqval, device=self.device)
             return torch.full(
                 self.batch_size,
-                bool(self.data != other.data),
+                bool(neqval),
                 device=self.device,
             )
         return _ne(self, other)
@@ -2223,9 +2236,14 @@ def __ne__(self, other):
     @functools.wraps(_xor)
     def __xor__(self, other):
         if isinstance(other, NonTensorData):
+            xorval = self.data ^ other.data
+            if isinstance(xorval, torch.Tensor):
+                return xorval
+            if isinstance(xorval, np.ndarray):
+                return torch.as_tensor(xorval, device=self.device)
             return torch.full(
                 self.batch_size,
-                bool(self.data ^ other.data),
+                bool(xorval),
                 device=self.device,
             )
         return _xor(self, other)
@@ -2237,9 +2255,14 @@ def __xor__(self, other):
     @functools.wraps(_or)
     def __or__(self, other):
         if isinstance(other, NonTensorData):
+            orval = self.data | other.data  # yuppie!
+            if isinstance(orval, torch.Tensor):
+                return orval
+            if isinstance(orval, np.ndarray):
+                return torch.as_tensor(orval, device=self.device)
             return torch.full(
                 self.batch_size,
-                bool(self.data | other.data),
+                bool(orval),
                 device=self.device,
            )
         return _or(self, other)
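
Note (not part of the patches above): a rough sketch of the NonTensorData comparison behavior the second patch appears to target, assuming the public `tensordict.NonTensorData` API; the constructor keywords and printed values below are illustrative only. Previously the result of `self.data == other.data` was always passed through `bool(...)`, which raises for array-like payloads; with the patch, a tensor result is returned as-is and a NumPy result is converted with `torch.as_tensor`.

    import numpy as np
    import torch
    from tensordict import NonTensorData

    # Scalar payloads: equality still collapses to a boolean tensor over batch_size.
    a = NonTensorData(data="foo", batch_size=[])
    b = NonTensorData(data="foo", batch_size=[])
    print(a == b)  # e.g. tensor(True)

    # Array-like payloads: the elementwise comparison is returned instead of bool(...).
    c = NonTensorData(data=np.array([1, 2, 3]), batch_size=[])
    d = NonTensorData(data=np.array([1, 0, 3]), batch_size=[])
    print(c == d)  # e.g. tensor([ True, False,  True])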