Skip to content

Commit

Permalink
[BugFix,CI] Fix sporadically failing tests in CI (#2098)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 23, 2024
1 parent 0ea236d commit fcbc6ea
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
10 changes: 6 additions & 4 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def _get_mock_input_td(
)
return td

@retry(AssertionError, 3)
@retry(AssertionError, 5)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
Expand All @@ -906,7 +906,7 @@ def test_multiagent_mlp(
n_agent_inputs,
n_agent_outputs=2,
):
torch.manual_seed(0)
torch.manual_seed(1)
mlp = MultiAgentMLP(
n_agent_inputs=n_agent_inputs,
n_agent_outputs=n_agent_outputs,
Expand Down Expand Up @@ -938,8 +938,10 @@ def test_multiagent_mlp(
elif i > 0:
assert torch.allclose(out[..., i, :], out2[..., i, :])

obs = torch.randn(*batch, 1, n_agent_inputs).expand(
*batch, n_agents, n_agent_inputs
obs = (
torch.randn(*batch, 1, n_agent_inputs)
.expand(*batch, n_agents, n_agent_inputs)
.clone()
)
out = mlp(obs)
for i in range(n_agents):
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7881,7 +7881,7 @@ def test_parallelenv_vecnorm(self):
parallel_sd = parallel_env.state_dict()
assert "worker0" in parallel_sd
worker_sd = parallel_sd["worker0"]
td = worker_sd["_extra_state"]["td"]
td = worker_sd["transforms.1._extra_state"]["td"]
queue_out.put("start")
msg = queue_in.get(timeout=TIMEOUT)
assert msg == "first round"
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _load_dataset(self):
from datasets import load_dataset, load_from_disk

if self.from_disk:
dataset = load_from_disk(self.dataset_name)[self.split]
dataset = load_from_disk(str(self.dataset_name))[self.split]
else:
dataset = load_dataset(self.dataset_name, split=self.split)
if self.split.startswith("valid"):
Expand Down

0 comments on commit fcbc6ea

Please sign in to comment.