Skip to content

Commit

Permalink
[Refactor] Do not lock nested tensordict in tensordictparams (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 23, 2023
1 parent 57fc236 commit dc4eb6b
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ class TensorDictParams(TensorDictBase, nn.Module):
If ``no_convert`` is ``True`` and if non-parameters are present, they
will be registered as buffers.
Defaults to ``False``.
lock (bool): if ``True``, the tensordict hosted by TensorDictParams will
be locked. This can be useful to avoid unwanted modifications, but it
also restricts the operations that can be performed on the object (and
can have a significant performance impact when `unlock_()` is required).
Defaults to ``False``.
Examples:
>>> from torch import nn
Expand Down Expand Up @@ -273,7 +278,9 @@ class TensorDictParams(TensorDictBase, nn.Module):
"""

def __init__(self, parameters: TensorDictBase, *, no_convert=False):
def __init__(
self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
):
super().__init__()
if isinstance(parameters, TensorDictParams):
parameters = parameters._param_td
Expand All @@ -283,7 +290,10 @@ def __init__(self, parameters: TensorDictBase, *, no_convert=False):
func = _maybe_make_param
else:
func = _maybe_make_param_or_buffer
self._param_td = _apply_leaves(self._param_td, lambda x: func(x)).lock_()
self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
self._lock = lock
if lock:
self._param_td.lock_()
self._reset_params()
self._is_locked = False
self._locked_tensordicts = []
Expand Down

0 comments on commit dc4eb6b

Please sign in to comment.