From 15e4926f5742cb4e1a9004fec6eb901e830c5940 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Apr 2024 04:47:05 -0400 Subject: [PATCH] fix: fix DPOSPath.save_numpy, DPH5Path.is_file, DPH5Path.is_dir (#3631) Signed-off-by: Jinzhe Zeng --- deepmd/utils/path.py | 19 ++++++++++-- source/tests/common/test_path.py | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 source/tests/common/test_path.py diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index afe14703a0..858e31a39d 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -11,6 +11,8 @@ Path, ) from typing import ( + ClassVar, + Dict, List, Optional, ) @@ -200,7 +202,8 @@ def save_numpy(self, arr: np.ndarray) -> None: """ if self.mode == "r": raise ValueError("Cannot save to read-only path") - np.save(str(self.path), arr) + with self.path.open("wb") as f: + np.save(f, arr) def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -354,6 +357,7 @@ def save_numpy(self, arr: np.ndarray) -> None: del self.root[self._name] self.root.create_dataset(self._name, data=arr) self.root.flush() + self._new_keys.append(self._name) def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -396,6 +400,14 @@ def _keys(self) -> List[str]: """Walk all groups and dataset.""" return self._file_keys(self.root) + __file_new_keys: ClassVar[Dict[h5py.File, List[str]]] = {} + + @property + def _new_keys(self): + """New keys that haven't been cached.""" + self.__file_new_keys.setdefault(self.root, []) + return self.__file_new_keys[self.root] + @classmethod @lru_cache(None) def _file_keys(cls, file: h5py.File) -> List[str]: @@ -406,7 +418,7 @@ def _file_keys(cls, file: h5py.File) -> List[str]: def is_file(self) -> bool: """Check if self is file.""" - if self._name not in self._keys: + if self._name not in self._keys and self._name not in self._new_keys: return False return isinstance(self.root[self._name], h5py.Dataset) @@ -414,7 +426,7 @@ def is_dir(self) -> bool: """Check if self is directory.""" if self._name == "/": return True - if self._name not in self._keys: + if self._name not in self._keys and self._name not in self._new_keys: return False return isinstance(self.root[self._name], h5py.Group) @@ -461,3 +473,4 @@ def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: self.root.require_group(self._name) else: self.root.create_group(self._name) + self._new_keys.append(self._name) diff --git a/source/tests/common/test_path.py b/source/tests/common/test_path.py new file mode 100644 index 0000000000..7dcb3a031c --- /dev/null +++ b/source/tests/common/test_path.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) + +import h5py +import numpy as np + +from deepmd.utils.path import ( + DPPath, +) + + +class PathTest: + path: DPPath + + def test_numpy(self): + numpy_path = self.path / "testcase" + arr1 = np.ones(3) + self.assertFalse(numpy_path.is_file()) + numpy_path.save_numpy(arr1) + self.assertTrue(numpy_path.is_file()) + arr2 = numpy_path.load_numpy() + np.testing.assert_array_equal(arr1, arr2) + + def test_dir(self): + dir_path = self.path / "testcase" + self.assertFalse(dir_path.is_dir()) + dir_path.mkdir() + self.assertTrue(dir_path.is_dir()) + + +class TestOSPath(PathTest, unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.TemporaryDirectory() + self.path = DPPath(self.tempdir.name, "a") + + def tearDown(self): + self.tempdir.cleanup() + + +class TestH5Path(PathTest, unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.path = DPPath(h5file, "a") + + def tearDown(self): + self.tempdir.cleanup()