Skip to content

Commit

Permalink
fix: fix DPOSPath.save_numpy, DPH5Path.is_file, DPH5Path.is_dir (#3631)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Apr 1, 2024
1 parent 2e6ab1b commit 15e4926
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
19 changes: 16 additions & 3 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
Path,
)
from typing import (
ClassVar,
Dict,
List,
Optional,
)
Expand Down Expand Up @@ -200,7 +202,8 @@ def save_numpy(self, arr: np.ndarray) -> None:
"""
if self.mode == "r":
raise ValueError("Cannot save to read-only path")
np.save(str(self.path), arr)
with self.path.open("wb") as f:
np.save(f, arr)

def glob(self, pattern: str) -> List["DPPath"]:
"""Search path using the glob pattern.
Expand Down Expand Up @@ -354,6 +357,7 @@ def save_numpy(self, arr: np.ndarray) -> None:
del self.root[self._name]
self.root.create_dataset(self._name, data=arr)
self.root.flush()
self._new_keys.append(self._name)

def glob(self, pattern: str) -> List["DPPath"]:
"""Search path using the glob pattern.
Expand Down Expand Up @@ -396,6 +400,14 @@ def _keys(self) -> List[str]:
"""Walk all groups and dataset."""
return self._file_keys(self.root)

# Class-wide registry: one list of freshly created key names per open HDF5 file.
__file_new_keys: ClassVar[Dict[h5py.File, List[str]]] = {}

@property
def _new_keys(self):
    """Return the list of keys recorded for ``self.root`` since it was opened.

    The list lives in the class-level ``__file_new_keys`` mapping, keyed by
    ``self.root``; an empty list is registered on first access. The returned
    object is the stored list itself, so appending to it updates the registry.
    """
    return self.__file_new_keys.setdefault(self.root, [])

@classmethod
@lru_cache(None)
def _file_keys(cls, file: h5py.File) -> List[str]:
Expand All @@ -406,15 +418,15 @@ def _file_keys(cls, file: h5py.File) -> List[str]:

def is_file(self) -> bool:
    """Tell whether this path names an existing dataset in the HDF5 file.

    Returns
    -------
    bool
        False when the name is found neither in ``_keys`` nor in
        ``_new_keys``; otherwise whether the object stored under the name
        is an ``h5py.Dataset``.
    """
    name = self._name
    # ``_new_keys`` is only consulted when the name is absent from ``_keys``.
    if name in self._keys or name in self._new_keys:
        return isinstance(self.root[name], h5py.Dataset)
    return False

def is_dir(self) -> bool:
    """Tell whether this path names a group in the HDF5 file.

    Returns
    -------
    bool
        True for the root name ``"/"``. False when the name is found neither
        in ``_keys`` nor in ``_new_keys``. Otherwise whether the object
        stored under the name is an ``h5py.Group``.
    """
    name = self._name
    if name == "/":
        # The root is treated as a directory without any lookup.
        return True
    # ``_new_keys`` is only consulted when the name is absent from ``_keys``.
    exists = name in self._keys or name in self._new_keys
    return exists and isinstance(self.root[name], h5py.Group)

Expand Down Expand Up @@ -461,3 +473,4 @@ def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None:
self.root.require_group(self._name)
else:
self.root.create_group(self._name)
self._new_keys.append(self._name)
53 changes: 53 additions & 0 deletions source/tests/common/test_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tempfile
import unittest
from pathlib import (
Path,
)

import h5py
import numpy as np

from deepmd.utils.path import (
DPPath,
)


class PathTest:
    """Shared test cases; concrete subclasses assign ``self.path`` in ``setUp``."""

    path: DPPath

    def test_numpy(self):
        """Saving an array makes the child path a file and the data loads back equal."""
        expected = np.ones(3)
        target = self.path / "testcase"
        # Nothing has been written yet, so the path must not be reported as a file.
        self.assertFalse(target.is_file())
        target.save_numpy(expected)
        self.assertTrue(target.is_file())
        np.testing.assert_array_equal(expected, target.load_numpy())

    def test_dir(self):
        """``mkdir`` turns a not-yet-existing child path into a directory."""
        target = self.path / "testcase"
        self.assertFalse(target.is_dir())
        target.mkdir()
        self.assertTrue(target.is_dir())


class TestOSPath(PathTest, unittest.TestCase):
    """Run the shared ``PathTest`` cases on a ``DPPath`` built from a temporary directory."""

    def setUp(self):
        """Create a fresh temporary directory and wrap it in a ``DPPath`` with mode ``"a"``."""
        scratch = tempfile.TemporaryDirectory()
        self.tempdir = scratch
        self.path = DPPath(scratch.name, "a")

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.tempdir.cleanup()


class TestH5Path(PathTest, unittest.TestCase):
    """Run the shared ``PathTest`` cases on a ``DPPath`` built from an HDF5 file path."""

    def setUp(self):
        """Create an empty HDF5 file in a temporary directory and wrap it in a ``DPPath``."""
        self.tempdir = tempfile.TemporaryDirectory()
        h5file = str(Path(self.tempdir.name, "testcase.h5").resolve())
        # Open in "w" mode and close immediately so the file exists on disk
        # before ``DPPath`` is constructed on it.
        with h5py.File(h5file, "w"):
            pass
        self.path = DPPath(h5file, "a")

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.tempdir.cleanup()

0 comments on commit 15e4926

Please sign in to comment.