Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cast input and output types at model's interface #3352

Merged
merged 9 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import numpy as np

from deepmd.common import (
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
)

Expand All @@ -21,6 +22,13 @@
"int64": np.int64,
"default": GLOBAL_NP_FLOAT_PRECISION,
}
RESERVED_PRECISON_DICT = {
np.float16: "float16",
np.float32: "float32",
np.float64: "float64",
np.int32: "int32",
np.int64: "int64",
}
DEFAULT_PRECISION = "float64"


Expand All @@ -35,3 +43,13 @@ def call(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
"""Forward pass in NumPy implementation."""
return self.call(*args, **kwargs)


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
"PRECISION_DICT",
"RESERVED_PRECISON_DICT",
"DEFAULT_PRECISION",
"NativeOP",
]
112 changes: 101 additions & 11 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.common import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
PRECISION_DICT,
RESERVED_PRECISON_DICT,
NativeOP,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
OutputVariableCategory,
OutputVariableOperation,
check_operation_applied,
)
from deepmd.dpmodel.utils import (
build_neighbor_list,
Expand Down Expand Up @@ -59,6 +67,10 @@
*args,
**kwargs,
)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION

Check warning on line 73 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L70-L73

Added lines #L70 - L73 were not covered by tests

def model_output_def(self):
"""Get the output def for the model."""
Expand Down Expand Up @@ -115,15 +127,19 @@

"""
nframes, nloc = atype.shape[:2]
if box is not None:
cc, bb, fp, ap, input_prec = self.input_type_cast(

Check warning on line 130 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L130

Added line #L130 was not covered by tests
coord, box=box, fparam=fparam, aparam=aparam
)
del coord, box, fparam, aparam
if bb is not None:

Check warning on line 134 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L133-L134

Added lines #L133 - L134 were not covered by tests
coord_normalized = normalize_coord(
coord.reshape(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = coord.copy()
coord_normalized = cc.copy()

Check warning on line 140 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L140

Added line #L140 was not covered by tests
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, box, self.get_rcut()
coord_normalized, atype, bb, self.get_rcut()
)
nlist = build_neighbor_list(
extended_coord,
Expand All @@ -139,8 +155,8 @@
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
Expand All @@ -149,6 +165,7 @@
mapping,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)

Check warning on line 168 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L168

Added line #L168 was not covered by tests
return model_predict

def call_lower(
Expand Down Expand Up @@ -192,22 +209,95 @@
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.reshape(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(

Check warning on line 212 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L212

Added line #L212 was not covered by tests
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam

Check warning on line 215 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L215

Added line #L215 was not covered by tests
atomic_ret = self.forward_atomic(
extended_coord,
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
fparam=fp,
aparam=ap,
)
model_predict = fit_output_to_model_output(
atomic_ret,
self.fitting_output_def(),
extended_coord,
cc_ext,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)

Check warning on line 230 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L230

Added line #L230 was not covered by tests
return model_predict

def input_type_cast(
self,
coord: np.ndarray,
box: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Tuple[
np.ndarray,
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
str,
]:
"""Cast the input data to global float type."""
input_prec = self.reverse_precision_dict[

Check warning on line 247 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L247

Added line #L247 was not covered by tests
self.precision_dict[coord.dtype.name]
]
###
### type checking would not pass jit, convert to coord prec anyway
###
_lst: List[Optional[np.ndarray]] = [

Check warning on line 253 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L253

Added line #L253 was not covered by tests
vv.astype(coord.dtype) if vv is not None else None
for vv in [box, fparam, aparam]
]
box, fparam, aparam = _lst
if (

Check warning on line 258 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L257-L258

Added lines #L257 - L258 were not covered by tests
input_prec
== self.reverse_precision_dict[self.global_np_float_precision]
):
return coord, box, fparam, aparam, input_prec

Check warning on line 262 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L262

Added line #L262 was not covered by tests
else:
pp = self.global_np_float_precision
return (

Check warning on line 265 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L264-L265

Added lines #L264 - L265 were not covered by tests
coord.astype(pp),
box.astype(pp) if box is not None else None,
fparam.astype(pp) if fparam is not None else None,
aparam.astype(pp) if aparam is not None else None,
input_prec,
)

def output_type_cast(
self,
model_ret: Dict[str, np.ndarray],
input_prec: str,
) -> Dict[str, np.ndarray]:
"""Convert the model output to the input prec."""
do_cast = (

Check warning on line 279 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L279

Added line #L279 was not covered by tests
input_prec
!= self.reverse_precision_dict[self.global_np_float_precision]
)
pp = self.precision_dict[input_prec]
odef = self.model_output_def()
for kk in odef.keys():
if kk not in model_ret.keys():

Check warning on line 286 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L283-L286

Added lines #L283 - L286 were not covered by tests
# do not return energy_derv_c if not do_atomic_virial
continue
if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
model_ret[kk] = (

Check warning on line 290 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L288-L290

Added lines #L288 - L290 were not covered by tests
model_ret[kk].astype(self.global_ener_float_precision)
if model_ret[kk] is not None
else None
)
elif do_cast:
model_ret[kk] = (

Check warning on line 296 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L295-L296

Added lines #L295 - L296 were not covered by tests
model_ret[kk].astype(pp) if model_ret[kk] is not None else None
)
return model_ret

Check warning on line 299 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L299

Added line #L299 was not covered by tests

def format_nlist(
self,
extended_coord: np.ndarray,
Expand Down
8 changes: 7 additions & 1 deletion deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import numpy as np

from deepmd.dpmodel.common import (
GLOBAL_ENER_FLOAT_PRECISION,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
Expand All @@ -30,7 +33,10 @@
atom_axis = -(len(shap) + 1)
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
model_ret[kk_redu] = np.sum(vv, axis=atom_axis)
# cast to energy prec brefore reduction
model_ret[kk_redu] = np.sum(

Check warning on line 37 in deepmd/dpmodel/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/transform_output.py#L37

Added line #L37 was not covered by tests
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
)
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name-holders
Expand Down
Loading