Skip to content

Commit

Permalink
update metadataupdates
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Nov 30, 2023
1 parent 03b865f commit 6d0c1f0
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,14 +511,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device)
loss_actor, actor_metadata = self.actor_loss(td_device)
loss_alpha, alpha_metadata = self.alpha_loss(td_device)
metadata.update(
{
**bc_metadata,
**cql_metadata,
**actor_metadata,
**alpha_metadata,
}
)
metadata.update(bc_metadata)
metadata.update(cql_metadata)
metadata.update(actor_metadata)
metadata.update(alpha_metadata)
tensordict_reshape.set(
self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
)
Expand Down

0 comments on commit 6d0c1f0

Please sign in to comment.