diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 03c9632a1..0a5107e9e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -390,7 +390,7 @@ def _new_unsafe( cls, parameters: TensorDictBase, *, - no_convert=False, + no_convert=None, lock: bool = False, params: dict | None = None, buffers: dict | None = None, @@ -399,24 +399,28 @@ def _new_unsafe( if is_compiling(): return TensorDictParams(parameters, no_convert="skip", lock=lock) - self = TensorDictParams.__new__(cls) - nn.Module.__init__(self) - if parameters is None: parameters = kwargs - elif kwargs: - raise TypeError( - f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args." - ) if isinstance(parameters, dict): - parameters = TensorDict._new_unsafe(parameters) + parameters = TensorDict._new_unsafe(parameters, **kwargs) + if no_convert is None: + # Then _new_unsafe is called from somewhere that doesn't know + # that it's a TDParams and we return a TensorDict (eg, torch.gather) + return parameters elif isinstance(parameters, TensorDictParams): + if kwargs: + raise TypeError( + f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args." + ) params = dict(parameters._parameters) buffers = dict(parameters._buffers) parameters = parameters._param_td no_convert = "skip" + self = TensorDictParams.__new__(cls) + nn.Module.__init__(self) + self._param_td = parameters self.no_convert = no_convert if no_convert != "skip":