Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve type anotations in deepmd.infer #3792

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
self.output_def = output_def
self.model_path = model_file
Expand Down Expand Up @@ -161,12 +161,12 @@ def get_ntypes_spin(self):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.

Expand Down
9 changes: 1 addition & 8 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@
)

if TYPE_CHECKING:
from deepmd.tf.infer import (
DeepDipole,
DeepDOS,
DeepPolar,
DeepPot,
DeepWFC,
)
from deepmd.tf.infer.deep_tensor import (
from deepmd.infer.deep_tensor import (

Check warning on line 49 in deepmd/entrypoints/test.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/test.py#L49

Added line #L49 was not covered by tests
DeepTensor,
)

Expand Down
3 changes: 1 addition & 2 deletions deepmd/infer/deep_dos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -70,7 +69,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Expand Down
23 changes: 12 additions & 11 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -82,10 +83,10 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
pass

Expand All @@ -99,12 +100,12 @@ def __new__(cls, model_file: str, *args, **kwargs):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.

Expand Down Expand Up @@ -166,13 +167,13 @@ def get_dim_aparam(self) -> int:
def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.

Expand Down Expand Up @@ -246,11 +247,11 @@ def _check_mixed_types(self, atom_types: np.ndarray) -> bool:
# assume mixed_types if there are virtual types, even when
# the atom types of all frames are the same
return False
return np.all(np.equal(atom_types, atom_types[0]))
return np.all(np.equal(atom_types, atom_types[0])).item()

@property
@abstractmethod
def model_type(self) -> "DeepEval":
def model_type(self) -> Type["DeepEval"]:
"""The the evaluator of the model type."""

@abstractmethod
Expand Down Expand Up @@ -316,10 +317,10 @@ def __new__(cls, model_file: str, *args, **kwargs):
def __init__(
self,
model_file: str,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
self.deep_eval = DeepEvalBackend(
model_file,
Expand Down Expand Up @@ -387,7 +388,7 @@ def eval_descriptor(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.

Expand Down
2 changes: 1 addition & 1 deletion deepmd/infer/deep_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: dict,
**kwargs,
) -> np.ndarray:
"""Evaluate the model.

Expand Down
47 changes: 45 additions & 2 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
overload,
)

import numpy as np
Expand Down Expand Up @@ -89,6 +90,48 @@
)
)

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: Literal[True],
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
pass

Check warning on line 105 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L105

Added line #L105 was not covered by tests

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: Literal[False],
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
pass

Check warning on line 119 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L119

Added line #L119 was not covered by tests

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: bool,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
pass

Check warning on line 133 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L133

Added line #L133 was not covered by tests

def eval(
self,
coords: np.ndarray,
Expand All @@ -98,7 +141,7 @@
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Expand Down
14 changes: 11 additions & 3 deletions deepmd/infer/model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: Literal[False] = False,
atomic: Literal[False] = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ...


Expand All @@ -37,11 +37,19 @@
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
*,
atomic: Literal[True],
atomic: Literal[True] = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...


@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: bool = False,
) -> Tuple[np.ndarray, ...]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
Expand Down
11 changes: 6 additions & 5 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -87,11 +88,11 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
head: Optional[str] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
self.output_def = output_def
self.model_path = model_file
Expand Down Expand Up @@ -165,7 +166,7 @@ def get_dim_aparam(self) -> int:
return self.dp.model["Default"].get_dim_aparam()

@property
def model_type(self) -> "DeepEvalWrapper":
def model_type(self) -> Type["DeepEvalWrapper"]:
"""The the evaluator of the model type."""
model_output_type = self.dp.model["Default"].model_output_type()
if "energy" in model_output_type:
Expand Down Expand Up @@ -211,12 +212,12 @@ def get_has_spin(self):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.

Expand Down
11 changes: 6 additions & 5 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -262,7 +263,7 @@ def _init_attr(self):

@property
@lru_cache(maxsize=None)
def model_type(self) -> "DeepEvalWrapper":
def model_type(self) -> Type["DeepEvalWrapper"]:
"""Get type of model.

:type:str
Expand Down Expand Up @@ -693,13 +694,13 @@ def _get_natoms_and_nframes(
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.

Expand Down Expand Up @@ -1023,7 +1024,7 @@ def _get_output_shape(self, odef, nframes, natoms):
def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def eval_descriptor(
def _eval_descriptor_inner(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_dim_aparam(self) -> int:
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: List[int],
atomic: bool = True,
fparam: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -276,7 +276,7 @@ def eval(
def eval_full(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: List[int],
atomic: bool = False,
fparam: Optional[np.array] = None,
Expand Down