[Refactor] Better weight update in collectors (#1723)
vmoens authored Nov 30, 2023
1 parent 6c27bdb · commit d545364
Showing 1 changed file with 10 additions and 16 deletions.
torchrl/collectors/collectors.py
@@ -280,11 +280,10 @@ def _get_policy_and_device(
         device = torch.device(device) if device is not None else policy_device
         get_weights_fn = None
         if policy_device != device:
-            param_and_buf = dict(policy.named_parameters())
-            param_and_buf.update(dict(policy.named_buffers()))
+            param_and_buf = TensorDict.from_module(policy, as_module=True)

             def get_weights_fn(param_and_buf=param_and_buf):
-                return TensorDict(param_and_buf, []).apply(lambda x: x.data)
+                return param_and_buf.data

             policy_cast = deepcopy(policy).requires_grad_(False).to(device)
             # here things may break bc policy.to("cuda") gives us weights on cuda:0 (same
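As an aside, a minimal sketch of the pattern this hunk adopts, assuming the tensordict package; the BatchNorm-bearing `policy` below is a made-up stand-in, not torchrl code:

    from tensordict import TensorDict
    from torch import nn

    policy = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

    # Old pattern: flatten parameters and buffers into a plain dict by hand.
    param_and_buf = dict(policy.named_parameters())
    param_and_buf.update(dict(policy.named_buffers()))

    # New pattern: one call captures both, with nested keys mirroring the
    # module tree; as_module=True yields a TensorDictParams bound to `policy`.
    params = TensorDict.from_module(policy, as_module=True)

    # .data plays the role of the former .apply(lambda x: x.data): its leaves
    # are detached views that still share storage with the module's tensors.
    weights = params.data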
@@ -308,9 +307,9 @@ def update_policy_weights_(
         """
         if policy_weights is not None:
-            self.policy_weights.apply(lambda x: x.data).update_(policy_weights)
+            self.policy_weights.data.update_(policy_weights)
         elif self.get_weights_fn is not None:
-            self.policy_weights.apply(lambda x: x.data).update_(self.get_weights_fn())
+            self.policy_weights.data.update_(self.get_weights_fn())

     def __iter__(self) -> Iterator[TensorDictBase]:
         return self.iterator()
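To illustrate the new one-liner, a hedged sketch assuming two structurally identical modules (`trainer_policy` and `collector_policy` are invented names): `update_` copies values into the tensors the TensorDict already holds, and `.data` routes the copy through detached views, so the module's parameters are refreshed in place.

    import torch
    from tensordict import TensorDict
    from torch import nn

    trainer_policy = nn.Linear(4, 2)    # source of fresh weights
    collector_policy = nn.Linear(4, 2)  # stale copy used for collection

    collector_weights = TensorDict.from_module(collector_policy, as_module=True)
    new_weights = TensorDict.from_module(trainer_policy).data  # detached values

    # In-place copy through the .data views: the collector's module sees the
    # new values without rebuilding or re-wrapping anything.
    collector_weights.data.update_(new_weights)

    assert torch.equal(collector_policy.weight, trainer_policy.weight)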
@@ -559,10 +558,7 @@ def __init__(
             )

         if isinstance(self.policy, nn.Module):
-            self.policy_weights = TensorDict(dict(self.policy.named_parameters()), [])
-            self.policy_weights.update(
-                TensorDict(dict(self.policy.named_buffers()), [])
-            )
+            self.policy_weights = TensorDict.from_module(self.policy, as_module=True)
         else:
             self.policy_weights = TensorDict({}, [])

@@ -1200,9 +1196,9 @@ def device_err_msg(device_name, devices_list):
                 )
             self._policy_dict[_device] = _policy
             if isinstance(_policy, nn.Module):
-                param_dict = dict(_policy.named_parameters())
-                param_dict.update(_policy.named_buffers())
-                self._policy_weights_dict[_device] = TensorDict(param_dict, [])
+                self._policy_weights_dict[_device] = TensorDict.from_module(
+                    _policy, as_module=True
+                )
             else:
                 self._policy_weights_dict[_device] = TensorDict({}, [])

@@ -1288,11 +1284,9 @@ def frames_per_batch_worker(self):
     def update_policy_weights_(self, policy_weights=None) -> None:
         for _device in self._policy_dict:
             if policy_weights is not None:
-                self._policy_weights_dict[_device].apply(lambda x: x.data).update_(
-                    policy_weights
-                )
+                self._policy_weights_dict[_device].data.update_(policy_weights)
             elif self._get_weights_fn_dict[_device] is not None:
-                self._policy_weights_dict[_device].update_(
+                self._policy_weights_dict[_device].data.update_(
                     self._get_weights_fn_dict[_device]()
                 )

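Finally, a rough mirror of the per-device loop above, with hypothetical `policy_dict` / `policy_weights_dict` stand-ins for the collector's internal state:

    import torch
    from tensordict import TensorDict
    from torch import nn

    policy_dict = {torch.device("cpu"): nn.Linear(4, 2)}
    policy_weights_dict = {
        dev: TensorDict.from_module(p, as_module=True)
        for dev, p in policy_dict.items()
    }
    get_weights_fn_dict = {dev: None for dev in policy_dict}

    def update_policy_weights_(policy_weights=None):
        for dev in policy_dict:
            if policy_weights is not None:
                # Same one-liner as the diff: one in-place copy per device.
                policy_weights_dict[dev].data.update_(policy_weights)
            elif get_weights_fn_dict[dev] is not None:
                policy_weights_dict[dev].data.update_(get_weights_fn_dict[dev]())

    # e.g. broadcast fresh weights to every device-specific copy:
    update_policy_weights_(TensorDict.from_module(nn.Linear(4, 2)).data)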