From 5546b57ed668e422636385aeba01e0a6b72e9d06 Mon Sep 17 00:00:00 2001 From: roger-creus Date: Fri, 5 Jul 2024 16:14:22 -0400 Subject: [PATCH] Fixed shape for MultiStep returns + Distributional loss --- torchrl/objectives/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 7b35598c474..86a11855ca8 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -560,7 +560,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: support = support.to("cpu") pns_a = pns_a.to("cpu") - Tz = reward + (1 - terminated.to(reward.dtype)) * discount * support + Tz = reward + (1 - terminated.to(reward.dtype)) * discount.unsqueeze(-1) * support.repeat(batch_size, 1) if Tz.shape != torch.Size([batch_size, atoms]): raise RuntimeError( "Tz shape must be torch.Size([batch_size, atoms]), "