From 3a24a2eba27cc217385591cdeaa181bda24e8e19 Mon Sep 17 00:00:00 2001
From: Shuji Suzuki
Date: Thu, 3 Sep 2020 09:05:41 +0900
Subject: [PATCH 1/3] add support dp and ddp in agent.save and agent.load

---
 pfrl/agent.py       |  6 ++++++
 tests/test_agent.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/pfrl/agent.py b/pfrl/agent.py
index 128e10d90..acf0ef148 100644
--- a/pfrl/agent.py
+++ b/pfrl/agent.py
@@ -102,6 +102,9 @@ def __save(self, dirname: str, ancestors: List[Any]):
             ), "Avoid an infinite loop"
             attr_value.__save(os.path.join(dirname, attr), ancestors)
         else:
+            if isinstance(attr_value, (torch.nn.parallel.DistributedDataParallel,
+                                       torch.nn.DataParallel)):
+                attr_value = attr_value.module
             torch.save(
                 attr_value.state_dict(), os.path.join(dirname, "{}.pt".format(attr))
             )
@@ -125,6 +128,9 @@ def __load(self, dirname: str, ancestors: List[Any]) -> None:
             ), "Avoid an infinite loop"
             attr_value.load(os.path.join(dirname, attr))
         else:
+            if isinstance(attr_value, (torch.nn.parallel.DistributedDataParallel,
+                                       torch.nn.DataParallel)):
+                attr_value = attr_value.module
             attr_value.load_state_dict(
                 torch.load(
                     os.path.join(dirname, "{}.pt".format(attr)), map_location
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 0d34e24c0..38a6aae33 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -91,3 +91,36 @@ def test_loop2(self):
         # Otherwise it seems to raise OSError: [Errno 63] File name too long
         with self.assertRaises(AssertionError):
             parent1.save(dirname)
+
+    def test_with_data_parallel(self):
+        parent = Parent()
+        parent.link.param.detach().numpy()[:] = 1
+        parent.child.link.param.detach().numpy()[:] = 2
+        parent.link = torch.nn.DataParallel(parent.link)
+
+        # Save
+        dirname = tempfile.mkdtemp()
+        parent.save(dirname)
+        self.assertTrue(os.path.isdir(dirname))
+        self.assertTrue(os.path.isfile(os.path.join(dirname, "link.pt")))
+        self.assertTrue(os.path.isdir(os.path.join(dirname, "child")))
+        self.assertTrue(os.path.isfile(os.path.join(dirname, "child", "link.pt")))
+
+        # Load Parent without data parallel
+        parent = Parent()
+        self.assertEqual(int(parent.link.param.detach().numpy()), 0)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 0)
+        parent.load(dirname)
+        self.assertEqual(int(parent.link.param.detach().numpy()), 1)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 2)
+
+        # Load Parent with data parallel
+        parent = Parent()
+        parent.link = torch.nn.DataParallel(parent.link)
+        self.assertEqual(int(parent.link.module.param.detach().numpy()), 0)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 0)
+        parent.load(dirname)
+        self.assertEqual(int(parent.link.module.param.detach().numpy()), 1)
+        self.assertEqual(int(parent.child.link.param.detach().numpy()), 2)
+
+

From a3033cfb8a39a2ab994cf625c3b4ca701cf0c4ec Mon Sep 17 00:00:00 2001
From: Shuji Suzuki
Date: Thu, 3 Sep 2020 09:17:15 +0900
Subject: [PATCH 2/3] fix flake8 errors

---
 tests/test_agent.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/test_agent.py b/tests/test_agent.py
index 38a6aae33..9f9271148 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -122,5 +122,3 @@ def test_with_data_parallel(self):
         parent.load(dirname)
         self.assertEqual(int(parent.link.module.param.detach().numpy()), 1)
         self.assertEqual(int(parent.child.link.param.detach().numpy()), 2)
-
-

From 867c1bbc2bb258326ec1483208e325b7c422dfdb Mon Sep 17 00:00:00 2001
From: Shuji Suzuki
Date: Tue, 13 Oct 2020 10:39:42 +0900
Subject: [PATCH 3/3] fix black errors

---
 pfrl/agent.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/pfrl/agent.py b/pfrl/agent.py
index acf0ef148..ecc5ac652 100644
--- a/pfrl/agent.py
+++ b/pfrl/agent.py
@@ -102,8 +102,10 @@ def __save(self, dirname: str, ancestors: List[Any]):
             ), "Avoid an infinite loop"
             attr_value.__save(os.path.join(dirname, attr), ancestors)
         else:
-            if isinstance(attr_value, (torch.nn.parallel.DistributedDataParallel,
-                                       torch.nn.DataParallel)):
+            if isinstance(
+                attr_value,
+                (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel),
+            ):
                 attr_value = attr_value.module
             torch.save(
                 attr_value.state_dict(), os.path.join(dirname, "{}.pt".format(attr))
@@ -128,8 +130,10 @@ def __load(self, dirname: str, ancestors: List[Any]) -> None:
             ), "Avoid an infinite loop"
             attr_value.load(os.path.join(dirname, attr))
         else:
-            if isinstance(attr_value, (torch.nn.parallel.DistributedDataParallel,
-                                       torch.nn.DataParallel)):
+            if isinstance(
+                attr_value,
+                (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel),
+            ):
                 attr_value = attr_value.module
             attr_value.load_state_dict(
                 torch.load(
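
Example (not part of the patch series): the unwrap-before-serialize pattern these patches add is what keeps checkpoints interchangeable between wrapped and bare modules. Below is a minimal standalone sketch of the same logic; the helper names save_unwrapped/load_unwrapped, the throwaway Linear module, and the temp path are illustrative only, not part of pfrl or the patch.

import os
import tempfile

import torch


def save_unwrapped(module, path):
    # Mirror the patch: strip a (Distributed)DataParallel wrapper so the
    # state dict is saved without the "module." key prefix.
    if isinstance(
        module, (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    ):
        module = module.module
    torch.save(module.state_dict(), path)


def load_unwrapped(module, path, map_location=None):
    # Same unwrapping on load, so a checkpoint restores correctly whether or
    # not either side is wrapped.
    if isinstance(
        module, (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    ):
        module = module.module
    module.load_state_dict(torch.load(path, map_location=map_location))


# A checkpoint written from a DataParallel-wrapped model loads into a bare one.
path = os.path.join(tempfile.mkdtemp(), "link.pt")  # hypothetical path
wrapped = torch.nn.DataParallel(torch.nn.Linear(4, 2))
save_unwrapped(wrapped, path)
load_unwrapped(torch.nn.Linear(4, 2), path)

Without the unwrap, the saved state dict would carry keys like "module.weight", and the bare-module load (the "Load Parent without data parallel" case in the test) would fail with missing/unexpected key errors.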