Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: plain text model format #4025

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DPModelBackend(Backend):
Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".dp"]
suffixes: ClassVar[List[str]] = [".dp", ".yaml", ".yml"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down
84 changes: 68 additions & 16 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from datetime import (
datetime,
)
from pathlib import (
Path,
)
from typing import (
Callable,
)

import h5py
import numpy as np
import yaml

try:
from deepmd._version import version as __version__
Expand All @@ -33,6 +38,8 @@
The model object after traversing.
"""
if isinstance(model_obj, dict):
if model_obj.get("@is_variable", False):
return callback(model_obj)
for kk, vv in model_obj.items():
model_obj[kk] = traverse_model_dict(
vv, callback, is_variable=is_variable or kk == "@variables"
Expand Down Expand Up @@ -78,22 +85,48 @@
The model dict to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
with h5py.File(filename, "w") as f:
filename_extension = Path(filename).suffix
extra_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
}
if filename_extension == ".dp":
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
)
save_dict = {
**extra_dict,
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
elif filename_extension in {".yaml", ".yml"}:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
lambda x: {
"@class": "np.ndarray",
"@is_variable": True,
"@version": 1,
"dtype": x.dtype.name,
"value": x.tolist(),
},
)
save_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
with open(filename, "w") as f:
yaml.safe_dump(
{
**extra_dict,
**model_dict,
},
f,
)
else:
raise ValueError(f"Unknown filename extension: {filename_extension}")

Check warning on line 129 in deepmd/dpmodel/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/serialization.py#L129

Added line #L129 was not covered by tests


def load_dp_model(filename: str) -> dict:
Expand All @@ -109,7 +142,26 @@
dict
The loaded model dict, including meta information.
"""
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
filename_extension = Path(filename).suffix
if filename_extension == ".dp":
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
elif filename_extension in {".yaml", ".yml"}:

def convert_numpy_ndarray(x):
if isinstance(x, dict) and x.get("@class") == "np.ndarray":
dtype = np.dtype(x["dtype"])
value = np.asarray(x["value"], dtype=dtype)
return value
return x

Check warning on line 157 in deepmd/dpmodel/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/serialization.py#L157

Added line #L157 was not covered by tests

with open(filename) as f:
model_dict = yaml.safe_load(f)
model_dict = traverse_model_dict(
model_dict,
convert_numpy_ndarray,
)
else:
raise ValueError(f"Unknown filename extension: {filename_extension}")

Check warning on line 166 in deepmd/dpmodel/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/serialization.py#L166

Added line #L166 was not covered by tests
return model_dict
8 changes: 5 additions & 3 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
This backend is only for development and should not be used in production.
:::

- Model filename extension: `.dp`
- Model filename extension: `.dp`, `.yaml`, `.yml`

DP is a reference backend for development, which uses pure [NumPy](https://numpy.org/) to implement models without using any heavy deep-learning frameworks.
Due to the limitation of NumPy, it doesn't support gradient calculation and thus cannot be used for training.
As a reference backend, it aims for correct results rather than the best performance.
The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent.
Only Python inference interface can load this format.
The DP backend has two formats, both of which are backend-independent:
The `.dp` format uses [HDF5](https://docs.h5py.org/) to store model serialization data, which has good performance.
The `.yaml` or `.yml` format uses [YAML](https://yaml.org/) to save the data as plain text, which is easy for humans to read.
Only the Python inference interface can load these formats.

NumPy 1.21 or above is required.

Expand Down
10 changes: 10 additions & 0 deletions source/tests/common/dpmodel/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def setUp(self) -> None:
],
}
self.filename = "test_dp_dpmodel.dp"
self.filename_yaml = "test_dp_dpmodel.yaml"

def test_save_load_model(self):
save_dp_model(self.filename, {"model": deepcopy(self.model_dict)})
Expand All @@ -291,6 +292,15 @@ def test_save_load_model(self):
assert "software" in model
assert "version" in model

def test_save_load_model_yaml(self):
save_dp_model(self.filename_yaml, {"model": deepcopy(self.model_dict)})
model = load_dp_model(self.filename_yaml)
np.testing.assert_equal(model["model"], self.model_dict)
assert "software" in model
assert "version" in model

def tearDown(self) -> None:
if os.path.exists(self.filename):
os.remove(self.filename)
if os.path.exists(self.filename_yaml):
os.remove(self.filename_yaml)