diff --git a/pfrl/agent.py b/pfrl/agent.py
index 128e10d90..ecc5ac652 100644
--- a/pfrl/agent.py
+++ b/pfrl/agent.py
@@ -102,6 +102,11 @@ def __save(self, dirname: str, ancestors: List[Any]):
                 ), "Avoid an infinite loop"
                 attr_value.__save(os.path.join(dirname, attr), ancestors)
             else:
+                if isinstance(
+                    attr_value,
+                    (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel),
+                ):
+                    attr_value = attr_value.module
                 torch.save(
                     attr_value.state_dict(), os.path.join(dirname, "{}.pt".format(attr))
                 )
@@ -125,6 +130,11 @@ def __load(self, dirname: str, ancestors: List[Any]) -> None:
                 ), "Avoid an infinite loop"
                 attr_value.load(os.path.join(dirname, attr))
             else:
+                if isinstance(
+                    attr_value,
+                    (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel),
+                ):
+                    attr_value = attr_value.module
                 attr_value.load_state_dict(
                     torch.load(
                         os.path.join(dirname, "{}.pt".format(attr)), map_location
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 0d34e24c0..9f9271148 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -91,3 +91,34 @@ def test_loop2(self):
         # Otherwise it seems to raise OSError: [Errno 63] File name too long
         with self.assertRaises(AssertionError):
             parent1.save(dirname)
+
+    def test_with_data_parallel(self):
+        parent = Parent()
+        parent.link.param.detach().numpy()[:] = 1
+        parent.child.link.param.detach().numpy()[:] = 2
+        parent.link = torch.nn.DataParallel(parent.link)
+
+        # Save
+        dirname = tempfile.mkdtemp()
+        parent.save(dirname)
+        self.assertTrue(os.path.isdir(dirname))
+        self.assertTrue(os.path.isfile(os.path.join(dirname, "link.pt")))
+        self.assertTrue(os.path.isdir(os.path.join(dirname, "child")))
+        self.assertTrue(os.path.isfile(os.path.join(dirname, "child", "link.pt")))
+
+        # Load Parent without data parallel
+        parent = Parent()
+        self.assertEqual(int(parent.link.param.detach().numpy()), 0)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 0)
+        parent.load(dirname)
+        self.assertEqual(int(parent.link.param.detach().numpy()), 1)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 2)
+
+        # Load Parent with data parallel
+        parent = Parent()
+        parent.link = torch.nn.DataParallel(parent.link)
+        self.assertEqual(int(parent.link.module.param.detach().numpy()), 0)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 0)
+        parent.load(dirname)
+        self.assertEqual(int(parent.link.module.param.detach().numpy()), 1)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 2)
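
Reviewer note (not part of the patch): the unwrap above is needed because
torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel register
the wrapped network under the attribute "module", so their state_dict keys
carry a "module." prefix that a bare module cannot load. A minimal sketch of
that behavior, assuming a CPU-only setup (DataParallel can be constructed
without GPUs):

    import torch

    net = torch.nn.Linear(2, 2)
    wrapped = torch.nn.DataParallel(net)

    # The wrapper prefixes every parameter key with "module.".
    print(list(net.state_dict()))      # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

    # Saving wrapped.module instead of the wrapper keeps the checkpoint
    # loadable both with and without the wrapper, which is what the test
    # above exercises at the AttributeSavingMixin level.
    state = wrapped.module.state_dict()
    net.load_state_dict(state)             # plain module: OK
    wrapped.module.load_state_dict(state)  # wrapped module: OK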