diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 1f1fc04e58d..c3e7dbc68ce 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -363,13 +363,19 @@ class hold_out_net(_context_manager): def __init__(self, network: nn.Module) -> None: self.network = network + for p in network.parameters(): + self.mode = p.requires_grad + break + else: + self.mode = True def __enter__(self) -> None: - self.params = TensorDict.from_module(self.network) - self.params.detach().to_module(self.network, return_swap=False) + if self.mode: + self.network.requires_grad_(False) def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.params.to_module(self.network, return_swap=False) + if self.mode: + self.network.requires_grad_() class hold_out_params(_context_manager):