From 186ba95c6949e9027a83ae7d45775f15cdad1464 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 23 Nov 2023 09:50:00 +0000 Subject: [PATCH] init --- tensordict/nn/params.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index f0a716fe8..9b713a3f9 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -234,6 +234,11 @@ class TensorDictParams(TensorDictBase, nn.Module): If ``no_convert`` is ``True`` and if non-parameters are present, they will be registered as buffers. Defaults to ``False``. + lock (bool): if ``True``, the tensordict hosted by TensorDictParams will + be locked. This can be useful to avoid unwanted modifications, but + also restricts the operations that can be done over the object (and + can have significant performance impact when `unlock_()` is required). + Defaults to ``False``. Examples: >>> from torch import nn @@ -273,7 +278,9 @@ class TensorDictParams(TensorDictBase, nn.Module): """ - def __init__(self, parameters: TensorDictBase, *, no_convert=False): + def __init__( + self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False + ): super().__init__() if isinstance(parameters, TensorDictParams): parameters = parameters._param_td @@ -283,7 +290,10 @@ def __init__(self, parameters: TensorDictBase, *, no_convert=False): func = _maybe_make_param else: func = _maybe_make_param_or_buffer - self._param_td = _apply_leaves(self._param_td, lambda x: func(x)).lock_() + self._param_td = _apply_leaves(self._param_td, lambda x: func(x)) + self._lock = lock + if lock: + self._param_td.lock_() self._reset_params() self._is_locked = False self._locked_tensordicts = []