Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cast input and output types at model's interface #3352

Merged
merged 9 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
"int64": np.int64,
"default": GLOBAL_NP_FLOAT_PRECISION,
}
RESERVED_PRECISON_DICT = {
np.float16: "float16",
np.float32: "float32",
np.float64: "float64",
np.int32: "int32",
np.int64: "int64",
}
DEFAULT_PRECISION = "float64"


Expand Down
95 changes: 84 additions & 11 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.common import (
GLOBAL_NP_FLOAT_PRECISION,
PRECISION_DICT,
RESERVED_PRECISON_DICT,
NativeOP,
)
from deepmd.dpmodel.output_def import (
Expand Down Expand Up @@ -59,6 +64,9 @@
*args,
**kwargs,
)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION

Check warning on line 69 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L67-L69

Added lines #L67 - L69 were not covered by tests

def model_output_def(self):
"""Get the output def for the model."""
Expand Down Expand Up @@ -115,15 +123,19 @@

"""
nframes, nloc = atype.shape[:2]
if box is not None:
cc, bb, fp, ap, input_prec = self.input_type_cast(

Check warning on line 126 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L126

Added line #L126 was not covered by tests
coord, box=box, fparam=fparam, aparam=aparam
)
del coord, box, fparam, aparam
if bb is not None:

Check warning on line 130 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L129-L130

Added lines #L129 - L130 were not covered by tests
coord_normalized = normalize_coord(
coord.reshape(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = coord.copy()
coord_normalized = cc.copy()

Check warning on line 136 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L136

Added line #L136 was not covered by tests
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, box, self.get_rcut()
coord_normalized, atype, bb, self.get_rcut()
)
nlist = build_neighbor_list(
extended_coord,
Expand All @@ -139,8 +151,8 @@
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
Expand All @@ -149,6 +161,7 @@
mapping,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)

Check warning on line 164 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L164

Added line #L164 was not covered by tests
return model_predict

def call_lower(
Expand Down Expand Up @@ -192,22 +205,82 @@
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.reshape(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(

Check warning on line 208 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L208

Added line #L208 was not covered by tests
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam

Check warning on line 211 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L211

Added line #L211 was not covered by tests
atomic_ret = self.forward_atomic(
extended_coord,
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
fparam=fp,
aparam=ap,
)
model_predict = fit_output_to_model_output(
atomic_ret,
self.fitting_output_def(),
extended_coord,
cc_ext,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)

Check warning on line 226 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L226

Added line #L226 was not covered by tests
return model_predict

def input_type_cast(
self,
coord: np.ndarray,
box: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Tuple[
np.ndarray,
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
str,
]:
"""Cast the input data to global float type."""
input_prec = self.reverse_precision_dict[

Check warning on line 243 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L243

Added line #L243 was not covered by tests
self.precision_dict[coord.dtype.name]
]
###
### type checking would not pass jit, convert to coord prec anyway
###
_lst: List[Optional[np.ndarray]] = [

Check warning on line 249 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L249

Added line #L249 was not covered by tests
vv.astype(coord.dtype) if vv is not None else None
for vv in [box, fparam, aparam]
]
box, fparam, aparam = _lst
if (

Check warning on line 254 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L253-L254

Added lines #L253 - L254 were not covered by tests
input_prec
== self.reverse_precision_dict[self.global_np_float_precision]
):
return coord, box, fparam, aparam, input_prec

Check warning on line 258 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L258

Added line #L258 was not covered by tests
else:
pp = self.global_np_float_precision
return (

Check warning on line 261 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L260-L261

Added lines #L260 - L261 were not covered by tests
coord.astype(pp),
box.astype(pp) if box is not None else None,
fparam.astype(pp) if fparam is not None else None,
aparam.astype(pp) if aparam is not None else None,
input_prec,
)

def output_type_cast(
self,
model_ret: Dict[str, np.ndarray],
input_prec: str,
) -> Dict[str, np.ndarray]:
"""Convert the model output to the input prec."""
if (

Check warning on line 275 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L275

Added line #L275 was not covered by tests
input_prec
!= self.reverse_precision_dict[self.global_np_float_precision]
):
pp = self.precision_dict[input_prec]
for kk, vv in model_ret.items():
model_ret[kk] = vv.astype(pp) if vv is not None else None
return model_ret

Check warning on line 282 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L279-L282

Added lines #L279 - L282 were not covered by tests

def format_nlist(
self,
extended_coord: np.ndarray,
Expand Down
95 changes: 87 additions & 8 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Dict,
List,
Optional,
Tuple,
)

import torch
Expand All @@ -17,6 +18,11 @@
communicate_extended_output,
fit_output_to_model_output,
)
from deepmd.pt.utils.env import (

Check warning on line 21 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L21

Added line #L21 was not covered by tests
GLOBAL_PT_FLOAT_PRECISION,
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
nlist_distinguish_types,
Expand Down Expand Up @@ -56,6 +62,9 @@
*args,
**kwargs,
)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION

Check warning on line 67 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L65-L67

Added lines #L65 - L67 were not covered by tests

def model_output_def(self):
"""Get the output def for the model."""
Expand Down Expand Up @@ -115,34 +124,39 @@
The keys are defined by the `ModelOutputDef`.

"""
cc, bb, fp, ap, input_prec = self.input_type_cast(

Check warning on line 127 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L127

Added line #L127 was not covered by tests
coord, box=box, fparam=fparam, aparam=aparam
)
del coord, box, fparam, aparam

Check warning on line 130 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L130

Added line #L130 was not covered by tests
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
cc,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
box=bb,
)
model_predict_lower = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
do_atomic_virial=do_atomic_virial,
fparam=fparam,
aparam=aparam,
fparam=fp,
aparam=ap,
)
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
mapping,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)

Check warning on line 159 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L159

Added line #L159 was not covered by tests
return model_predict

def forward_common_lower(
Expand Down Expand Up @@ -186,22 +200,87 @@
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.view(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(

Check warning on line 203 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L203

Added line #L203 was not covered by tests
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam

Check warning on line 206 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L206

Added line #L206 was not covered by tests
atomic_ret = self.forward_atomic(
extended_coord,
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
fparam=fp,
aparam=ap,
)
model_predict = fit_output_to_model_output(
atomic_ret,
self.fitting_output_def(),
extended_coord,
cc_ext,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)

Check warning on line 221 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L221

Added line #L221 was not covered by tests
return model_predict

def input_type_cast(

Check warning on line 224 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L224

Added line #L224 was not covered by tests
self,
coord: torch.Tensor,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
str,
]:
"""Cast the input data to global float type."""
input_prec = self.reverse_precision_dict[coord.dtype]

Check warning on line 238 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L238

Added line #L238 was not covered by tests
###
### type checking would not pass jit, convert to coord prec anyway
###
# for vv, kk in zip([fparam, aparam], ["frame", "atomic"]):
# if vv is not None and self.reverse_precision_dict[vv.dtype] != input_prec:
# log.warning(
# f"type of {kk} parameter {self.reverse_precision_dict[vv.dtype]}"
# " does not match"
# f" that of the coordinate {input_prec}"
Comment on lines +246 to +251

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
# )
_lst: List[Optional[torch.Tensor]] = [

Check warning on line 249 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L249

Added line #L249 was not covered by tests
vv.to(coord.dtype) if vv is not None else None
for vv in [box, fparam, aparam]
]
box, fparam, aparam = _lst
if (

Check warning on line 254 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L253-L254

Added lines #L253 - L254 were not covered by tests
input_prec
== self.reverse_precision_dict[self.global_pt_float_precision]
):
return coord, box, fparam, aparam, input_prec

Check warning on line 258 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L258

Added line #L258 was not covered by tests
else:
pp = self.global_pt_float_precision
return (

Check warning on line 261 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L260-L261

Added lines #L260 - L261 were not covered by tests
coord.to(pp),
box.to(pp) if box is not None else None,
fparam.to(pp) if fparam is not None else None,
aparam.to(pp) if aparam is not None else None,
input_prec,
)

def output_type_cast(

Check warning on line 269 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L269

Added line #L269 was not covered by tests
self,
model_ret: Dict[str, torch.Tensor],
input_prec: str,
) -> Dict[str, torch.Tensor]:
"""Convert the model output to the input prec."""
if (

Check warning on line 275 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L275

Added line #L275 was not covered by tests
input_prec
!= self.reverse_precision_dict[self.global_pt_float_precision]
):
pp = self.precision_dict[input_prec]
for kk, vv in model_ret.items():
model_ret[kk] = vv.to(pp) if vv is not None else None
njzjz marked this conversation as resolved.
Show resolved Hide resolved
return model_ret

Check warning on line 282 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L279-L282

Added lines #L279 - L282 were not covered by tests

@torch.jit.export
def format_nlist(
self,
Expand Down
14 changes: 14 additions & 0 deletions deepmd/pt/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
exclude_types: List[int] = [],
):
super().__init__()
self.reinit(ntypes, exclude_types)

Check warning on line 25 in deepmd/pt/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/exclude_mask.py#L25

Added line #L25 was not covered by tests

def reinit(

Check warning on line 27 in deepmd/pt/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/exclude_mask.py#L27

Added line #L27 was not covered by tests
self,
ntypes: int,
exclude_types: List[int] = [],
):
self.ntypes = ntypes
self.exclude_types = exclude_types
self.type_mask = np.array(
Expand Down Expand Up @@ -62,6 +69,13 @@
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit(ntypes, exclude_types)

Check warning on line 72 in deepmd/pt/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/exclude_mask.py#L72

Added line #L72 was not covered by tests

def reinit(

Check warning on line 74 in deepmd/pt/utils/exclude_mask.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/exclude_mask.py#L74

Added line #L74 was not covered by tests
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
self.ntypes = ntypes
self._exclude_types: Set[Tuple[int, int]] = set()
for tt in exclude_types:
Expand Down
Loading