From 5277b054c404b2695131442b5a6d238ae9b396f8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 15:44:16 +0000 Subject: [PATCH] init --- torchrl/objectives/utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 1f1fc04e58d..c3e7dbc68ce 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -363,13 +363,19 @@ class hold_out_net(_context_manager): def __init__(self, network: nn.Module) -> None: self.network = network + for p in network.parameters(): + self.mode = p.requires_grad + break + else: + self.mode = True def __enter__(self) -> None: - self.params = TensorDict.from_module(self.network) - self.params.detach().to_module(self.network, return_swap=False) + if self.mode: + self.network.requires_grad_(False) def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.params.to_module(self.network, return_swap=False) + if self.mode: + self.network.requires_grad_() class hold_out_params(_context_manager):