From 6d3aaaaf7ec3e5ed2775a7af74e4379dd4c258c2 Mon Sep 17 00:00:00 2001 From: Sidhant Sundrani Date: Thu, 26 Nov 2020 23:53:22 +0530 Subject: [PATCH] fix loss test case for batch size variation (#402) --- tests/models/rl/unit/test_reinforce.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/rl/unit/test_reinforce.py b/tests/models/rl/unit/test_reinforce.py index 9eb5ca8796..72a3641d9c 100644 --- a/tests/models/rl/unit/test_reinforce.py +++ b/tests/models/rl/unit/test_reinforce.py @@ -37,9 +37,9 @@ def setUp(self) -> None: def test_loss(self): """Test the reinforce loss function""" - batch_states = torch.rand(32, 4) - batch_actions = torch.rand(32).long() - batch_qvals = torch.rand(32) + batch_states = torch.rand(16, 4) + batch_actions = torch.rand(16).long() + batch_qvals = torch.rand(16) loss = self.model.loss(batch_states, batch_actions, batch_qvals)