Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: infer model type from ModelOutputDef #3250

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,26 @@
OutputVariableCategory,
OutputVariableDef,
)
from deepmd.infer.deep_dipole import (
DeepDipole,
)
Comment on lines +21 to +23

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_dipole
begins an import cycle.
from deepmd.infer.deep_dos import (
DeepDOS,
)
Comment on lines +24 to +26

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_dos
begins an import cycle.
from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_eval
begins an import cycle.
from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.infer.deep_polar import (
DeepGlobalPolar,
DeepPolar,
)
Comment on lines +31 to +34

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_polar
begins an import cycle.
from deepmd.infer.deep_pot import (
DeepPot,
)
from deepmd.infer.deep_wfc import (
DeepWFC,
)
Comment on lines +38 to +40

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_wfc
begins an import cycle.
from deepmd.pt.model.model import (
get_model,
)
Expand All @@ -44,8 +58,6 @@
if TYPE_CHECKING:
import ase.neighborlist

from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper


class DeepEval(DeepEvalBackend):
"""PyTorch backend implementaion of DeepEval.
Expand Down Expand Up @@ -127,7 +139,22 @@
@property
def model_type(self) -> "DeepEvalWrapper":
"""The the evaluator of the model type."""
return DeepPot
output_def = self.dp.model["Default"].model_output_def()
var_defs = output_def.var_defs
if "energy" in var_defs:
return DeepPot
elif "dos" in var_defs:
return DeepDOS
elif "dipole" in var_defs:
return DeepDipole
elif "polar" in var_defs:
return DeepPolar
elif "global_polar" in var_defs:
return DeepGlobalPolar
elif "wfc" in var_defs:
return DeepWFC

Check warning on line 155 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L142-L155

Added lines #L142 - L155 were not covered by tests
else:
raise RuntimeError("Unknown model type")

Check warning on line 157 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L157

Added line #L157 was not covered by tests

def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
**kwargs,
)

@torch.jit.export
def model_output_def(self):
"""Get the output def for the model."""
return ModelOutputDef(self.fitting_output_def())
Expand Down