Skip to content

Commit

Permalink
[BugFix] Fix hold_out_net (#1719)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 28, 2023
1 parent 07fcfb1 commit 2a72e6d
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,19 @@ class hold_out_net(_context_manager):

def __init__(self, network: nn.Module) -> None:
self.network = network
for p in network.parameters():
self.mode = p.requires_grad
break
else:
self.mode = True

def __enter__(self) -> None:
self.params = TensorDict.from_module(self.network)
self.params.detach().to_module(self.network, return_swap=False)
if self.mode:
self.network.requires_grad_(False)

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.params.to_module(self.network, return_swap=False)
if self.mode:
self.network.requires_grad_()


class hold_out_params(_context_manager):
Expand Down

0 comments on commit 2a72e6d

Please sign in to comment.