From fcbc6ea17964a4445d61500aa1eb6ed88771a24e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Apr 2024 14:05:42 +0100 Subject: [PATCH] [BugFix,CI] Fix sporadically failing tests in CI (#2098) --- test/test_modules.py | 10 ++++++---- test/test_transforms.py | 2 +- torchrl/data/rlhf/dataset.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index de4333a3254..c9984e178c5 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -891,7 +891,7 @@ def _get_mock_input_td( ) return td - @retry(AssertionError, 3) + @retry(AssertionError, 5) @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize("share_params", [True, False]) @pytest.mark.parametrize("centralised", [True, False]) @@ -906,7 +906,7 @@ def test_multiagent_mlp( n_agent_inputs, n_agent_outputs=2, ): - torch.manual_seed(0) + torch.manual_seed(1) mlp = MultiAgentMLP( n_agent_inputs=n_agent_inputs, n_agent_outputs=n_agent_outputs, @@ -938,8 +938,10 @@ def test_multiagent_mlp( elif i > 0: assert torch.allclose(out[..., i, :], out2[..., i, :]) - obs = torch.randn(*batch, 1, n_agent_inputs).expand( - *batch, n_agents, n_agent_inputs + obs = ( + torch.randn(*batch, 1, n_agent_inputs) + .expand(*batch, n_agents, n_agent_inputs) + .clone() ) out = mlp(obs) for i in range(n_agents): diff --git a/test/test_transforms.py b/test/test_transforms.py index 124e63aaad3..36408bf4964 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7881,7 +7881,7 @@ def test_parallelenv_vecnorm(self): parallel_sd = parallel_env.state_dict() assert "worker0" in parallel_sd worker_sd = parallel_sd["worker0"] - td = worker_sd["_extra_state"]["td"] + td = worker_sd["transforms.1._extra_state"]["td"] queue_out.put("start") msg = queue_in.get(timeout=TIMEOUT) assert msg == "first round" diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 8f039b317fc..c1411b81a09 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -164,7 +164,7 @@ def _load_dataset(self): from datasets import load_dataset, load_from_disk if self.from_disk: - dataset = load_from_disk(self.dataset_name)[self.split] + dataset = load_from_disk(str(self.dataset_name))[self.split] else: dataset = load_dataset(self.dataset_name, split=self.split) if self.split.startswith("valid"):