Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 16, 2025
1 parent f6f38e0 commit 8a7c4bc
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def _new_unsafe(
cls,
parameters: TensorDictBase,
*,
no_convert=False,
no_convert=None,
lock: bool = False,
params: dict | None = None,
buffers: dict | None = None,
Expand All @@ -399,24 +399,28 @@ def _new_unsafe(
if is_compiling():
return TensorDictParams(parameters, no_convert="skip", lock=lock)

self = TensorDictParams.__new__(cls)
nn.Module.__init__(self)

if parameters is None:
parameters = kwargs
elif kwargs:
raise TypeError(
f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
)

if isinstance(parameters, dict):
parameters = TensorDict._new_unsafe(parameters)
parameters = TensorDict._new_unsafe(parameters, **kwargs)
if no_convert is None:
# Then _new_unsafe is called from somewhere that doesn't know
# that it's a TDParams and we return a TensorDict (eg, torch.gather)
return parameters
elif isinstance(parameters, TensorDictParams):
if kwargs:
raise TypeError(
f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
)
params = dict(parameters._parameters)
buffers = dict(parameters._buffers)
parameters = parameters._param_td
no_convert = "skip"

self = TensorDictParams.__new__(cls)
nn.Module.__init__(self)

self._param_td = parameters
self.no_convert = no_convert
if no_convert != "skip":
Expand Down

0 comments on commit 8a7c4bc

Please sign in to comment.