From 6d0c1f056a5181ab160a5fc0641cfe35b4956095 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 30 Nov 2023 13:00:55 +0100 Subject: [PATCH] update metadataupdates --- torchrl/objectives/cql.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 268be721235..86d5461b15f 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -511,14 +511,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device) loss_actor, actor_metadata = self.actor_loss(td_device) loss_alpha, alpha_metadata = self.alpha_loss(td_device) - metadata.update( - { - **bc_metadata, - **cql_metadata, - **actor_metadata, - **alpha_metadata, - } - ) + metadata.update(bc_metadata) + metadata.update(cql_metadata) + metadata.update(actor_metadata) + metadata.update(alpha_metadata) tensordict_reshape.set( self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values )