From 1a7f43ac7f66898c388c9696c14186f7075e4f60 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 23 Nov 2023 10:54:34 +0000 Subject: [PATCH] [Performance] Faster params and buffer registration in TensorDictParams (#569) --- tensordict/nn/params.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 9b713a3f9..6cedbf41d 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -317,18 +317,21 @@ def _apply_get_post_hook(self, val): def _reset_params(self): parameters = self._param_td param_keys = [] + params = [] buffer_keys = [] + buffers = [] for key, value in parameters.items(True, True): + # flatten key + if isinstance(key, tuple): + key = "_".join(key) if isinstance(value, nn.Parameter): param_keys.append(key) + params.append(value) else: buffer_keys.append(key) - self.__dict__["_parameters"] = ( - parameters.select(*param_keys).flatten_keys("_").to_dict() - ) - self.__dict__["_buffers"] = ( - parameters.select(*buffer_keys).flatten_keys("_").to_dict() - ) + buffers.append(value) + self.__dict__["_parameters"] = dict(zip(param_keys, params)) + self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers)) @classmethod def __torch_function__(