diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 268be721235..86d5461b15f 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -511,14 +511,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device) loss_actor, actor_metadata = self.actor_loss(td_device) loss_alpha, alpha_metadata = self.alpha_loss(td_device) - metadata.update( - { - **bc_metadata, - **cql_metadata, - **actor_metadata, - **alpha_metadata, - } - ) + metadata.update(bc_metadata) + metadata.update(cql_metadata) + metadata.update(actor_metadata) + metadata.update(alpha_metadata) tensordict_reshape.set( self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values )