Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: numpy pairtab model #3212

Merged
merged 58 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
47cff4b
feat: add pair table model to pytorch
Jan 28, 2024
04b6f57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
eb59d87
fix: typo
Jan 28, 2024
b7cbbd5
fix: typo
Jan 28, 2024
a1a76bb
Merge branch 'devel' into devel
anyangml Jan 28, 2024
84767f3
fix: update ruct extrapolation
Jan 28, 2024
8fee8fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
ff08515
fix: update allclose precision
Jan 28, 2024
f4b3720
Merge branch 'devel' into devel
anyangml Jan 29, 2024
451916e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
0968eaa
Merge branch 'devel' into devel
anyangml Jan 29, 2024
6b0559e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
8cbb98c
chore: refactor common method to PairTab
Jan 29, 2024
a08092c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
d3090b9
fix: update unit tests
Jan 29, 2024
daf2fc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
399c278
fix: revert padding zero mask change
Jan 29, 2024
59abe43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
8f1cdc8
Merge branch 'devel' into devel
anyangml Jan 30, 2024
88936cc
Merge branch 'devel' into devel
anyangml Jan 30, 2024
1c4ee0d
feat: redo extrapolation with cubic spline for smoothness
Jan 30, 2024
5793828
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
27f3559
Merge branch 'devel' into devel
anyangml Jan 30, 2024
92dec18
chore: refactor _make_data in PairTab
Jan 30, 2024
bc04359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
4433035
chore: move file
Jan 30, 2024
2ba0318
Merge branch 'devel' into devel
anyangml Jan 30, 2024
f2c40e6
Merge branch 'devel' into devel
anyangml Jan 31, 2024
4851a0a
chore: refactor extrapolation code
Jan 31, 2024
ddbe7db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2024
29d95db
Merge branch 'devel' into devel
anyangml Jan 31, 2024
365c20d
feat: add zbl weighted model
Feb 1, 2024
fb4ae7d
Merge branch 'deepmodeling:devel' into devel
anyangml Feb 1, 2024
e423e68
feat: add serialize and deserialize to pt pairtab
Feb 1, 2024
39da7ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2024
74002ed
chore: remove irrelevant files
Feb 1, 2024
0ce23f4
feat: add numpy version
Feb 1, 2024
5190903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2024
f5541d1
fix: redo pairtab pt serialization
Feb 2, 2024
de3296a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
43fcea7
feat: test pairtabmodel numpy version
Feb 2, 2024
9716144
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
b09f43b
fix: precommit
Feb 2, 2024
05d2750
fix: precommit
Feb 2, 2024
263b6dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
35232c2
chore: refactor code
Feb 2, 2024
c72bdb9
fix: at @variables key to serialize
Feb 2, 2024
f1e5d72
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
6950e98
fix: import
Feb 2, 2024
5083865
Merge remote-tracking branch 'upstream/devel' into feat/numpy_pairtab…
Feb 2, 2024
97ccfcc
Merge branch 'devel' into feat/numpy_pairtab_model
anyangml Feb 2, 2024
39da832
fix: rename method
Feb 2, 2024
1a59917
fix: import
anyangml Feb 2, 2024
a59a930
Merge branch 'devel' into feat/numpy_pairtab_model
anyangml Feb 3, 2024
97c6a88
Merge branch 'devel' into feat/numpy_pairtab_model
anyangml Feb 4, 2024
a6c637a
fix: change input output shape and move UTs
anyangml Feb 4, 2024
3b0faa4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
ec04b5d
fix: array type
anyangml Feb 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 295 additions & 0 deletions deepmd/dpmodel/model/pair_tab_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L2

Added line #L2 was not covered by tests
Dict,
List,
Optional,
Union,
)

import numpy as np

Check warning on line 9 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L9

Added line #L9 was not covered by tests

from deepmd.dpmodel.output_def import (

Check warning on line 11 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L11

Added line #L11 was not covered by tests
FittingOutputDef,
OutputVariableDef,
)
from deepmd.utils.pair_tab import (

Check warning on line 15 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L15

Added line #L15 was not covered by tests
PairTab,
)

from .base_atomic_model import (

Check warning on line 19 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L19

Added line #L19 was not covered by tests
BaseAtomicModel,
)


class PairTabModel(BaseAtomicModel):

Check warning on line 24 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L24

Added line #L24 was not covered by tests
"""Pairwise tabulation energy model.

This model can be used to tabulate the pairwise energy between atoms for either
short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
be used alone, but rather as one submodel of a linear (sum) model, such as
DP+D3.

Do not put the model on the first model of a linear model, since the linear
model fetches the type map from the first model.

At this moment, the model does not smooth the energy at the cutoff radius, so
one needs to make sure the energy has been smoothed to zero.

Parameters
----------
tab_file : str
The path to the tabulation file.
rcut : float
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
"""

def __init__(

Check warning on line 48 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L48

Added line #L48 was not covered by tests
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
):
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

Check warning on line 53 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L51-L53

Added lines #L51 - L53 were not covered by tests

self.tab = PairTab(self.tab_file, rcut=rcut)

Check warning on line 55 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L55

Added line #L55 was not covered by tests

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()

Check warning on line 58 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L57-L58

Added lines #L57 - L58 were not covered by tests
else:
self.tab_info, self.tab_data = None, None

Check warning on line 60 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L60

Added line #L60 was not covered by tests

if isinstance(sel, int):
self.sel = sel
elif isinstance(sel, list):
self.sel = sum(sel)

Check warning on line 65 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L62-L65

Added lines #L62 - L65 were not covered by tests
else:
raise TypeError("sel must be int or list[int]")

Check warning on line 67 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L67

Added line #L67 was not covered by tests

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(

Check warning on line 70 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L69-L70

Added lines #L69 - L70 were not covered by tests
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
)
]
)

def get_rcut(self) -> float:
return self.rcut

Check warning on line 79 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L78-L79

Added lines #L78 - L79 were not covered by tests

def get_sel(self) -> int:
return self.sel

Check warning on line 82 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L81-L82

Added lines #L81 - L82 were not covered by tests

def distinguish_types(self) -> bool:

Check warning on line 84 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L84

Added line #L84 was not covered by tests
# to match DPA1 and DPA2.
return False

Check warning on line 86 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L86

Added line #L86 was not covered by tests

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}

Check warning on line 89 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L88-L89

Added lines #L88 - L89 were not covered by tests

@classmethod
def deserialize(cls, data) -> "PairTabModel":
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
return tab_model

Check warning on line 100 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L91-L100

Added lines #L91 - L100 were not covered by tests

def forward_atomic(

Check warning on line 102 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L102

Added line #L102 was not covered by tests
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[np.array] = None,
anyangml marked this conversation as resolved.
Show resolved Hide resolved
do_atomic_virial: bool = False,
) -> Dict[str, np.array]:
self.nframes, self.nloc, self.nnei = nlist.shape

Check warning on line 110 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L110

Added line #L110 was not covered by tests

# this will mask all -1 in the nlist
masked_nlist = np.clip(nlist, 0, None)

Check warning on line 113 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L113

Added line #L113 was not covered by tests

atype = extended_atype[:, : self.nloc] # (nframes, nloc)
pairwise_dr = self._get_pairwise_dist(

Check warning on line 116 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L115-L116

Added lines #L115 - L116 were not covered by tests
extended_coord
) # (nframes, nall, nall, 3)
pairwise_rr = np.sqrt(

Check warning on line 119 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L119

Added line #L119 was not covered by tests
np.sum(np.power(pairwise_dr, 2), axis=-1)
) # (nframes, nall, nall)
self.tab_data = self.tab_data.reshape(

Check warning on line 122 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L122

Added line #L122 was not covered by tests
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# (nframes, nloc, nnei)
j_type = extended_atype[

Check warning on line 127 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L127

Added line #L127 was not covered by tests
np.arange(extended_atype.shape[0])[:, None, None], masked_nlist
]

# slice rr to get (nframes, nloc, nnei)
rr = np.take_along_axis(pairwise_rr[:, : self.nloc, :], masked_nlist, 2)
raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)
atomic_energy = 0.5 * np.sum(

Check warning on line 134 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L132-L134

Added lines #L132 - L134 were not covered by tests
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
axis=-1,
)

return {"energy": atomic_energy}

Check warning on line 139 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L139

Added line #L139 was not covered by tests

def _pair_tabulated_inter(

Check warning on line 141 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L141

Added line #L141 was not covered by tests
self,
nlist: np.array,
i_type: np.array,
j_type: np.array,
rr: np.array,
) -> np.array:
"""Pairwise tabulated energy.

Parameters
----------
nlist : np.array
The unmasked neighbour list. (nframes, nloc)
i_type : np.array
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : np.array
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
rr : np.array
The salar distance vector between two atoms. (nframes, nloc, nnei)

Returns
-------
np.array
The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)

Raises
------
Exception
If the distance is beyond the table.

Notes
-----
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

Check warning on line 178 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L176-L178

Added lines #L176 - L178 were not covered by tests

self.nspline = int(self.tab_info[2] + 0.1)

Check warning on line 180 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L180

Added line #L180 was not covered by tests

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

Check warning on line 182 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L182

Added line #L182 was not covered by tests

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = np.where(nlist != -1, uu, self.nspline + 1)

Check warning on line 186 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L186

Added line #L186 was not covered by tests

if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")

Check warning on line 189 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L188-L189

Added lines #L188 - L189 were not covered by tests

idx = uu.astype(int)

Check warning on line 191 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L191

Added line #L191 was not covered by tests

uu -= idx
table_coef = self._extract_spline_coefficient(

Check warning on line 194 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L193-L194

Added lines #L193 - L194 were not covered by tests
i_type, j_type, idx, self.tab_data, self.nspline
)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
ener = self._calcualte_ener(table_coef, uu)

Check warning on line 198 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L197-L198

Added lines #L197 - L198 were not covered by tests
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut

Check warning on line 200 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L200

Added line #L200 was not covered by tests
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0

Check warning on line 204 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L202-L204

Added lines #L202 - L204 were not covered by tests

return ener

Check warning on line 206 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L206

Added line #L206 was not covered by tests

@staticmethod
def _get_pairwise_dist(coords: np.array) -> np.array:

Check warning on line 209 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L208-L209

Added lines #L208 - L209 were not covered by tests
"""Get pairwise distance `dr`.

Parameters
----------
coords : np.array
The coordinate of the atoms shape of (nframes * nall * 3).

Returns
-------
np.array
The pairwise distance between the atoms (nframes * nall * nall * 3).
"""
return np.expand_dims(coords, 2) - np.expand_dims(coords, 1)

Check warning on line 222 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L222

Added line #L222 was not covered by tests

@staticmethod
def _extract_spline_coefficient(

Check warning on line 225 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L224-L225

Added lines #L224 - L225 were not covered by tests
i_type: np.array,
j_type: np.array,
idx: np.array,
tab_data: np.array,
anyangml marked this conversation as resolved.
Show resolved Hide resolved
nspline: int,
) -> np.array:
"""Extract the spline coefficient from the table.

Parameters
----------
i_type : np.array
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : np.array
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
idx : np.array
The index of the spline coefficient. (nframes, nloc, nnei)
tab_data : np.array
The table storing all the spline coefficient. (ntype, ntype, nspline, 4)
nspline : int
The number of splines in the table.

Returns
-------
np.array
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
"""
# (nframes, nloc, nnei)
expanded_i_type = np.broadcast_to(

Check warning on line 253 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L253

Added line #L253 was not covered by tests
i_type[:, :, np.newaxis],
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
)

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

Check warning on line 259 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L259

Added line #L259 was not covered by tests

# (nframes, nloc, nnei, 1, 4)
expanded_idx = np.broadcast_to(

Check warning on line 262 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L262

Added line #L262 was not covered by tests
idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
)
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)

Check warning on line 265 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L265

Added line #L265 was not covered by tests

# (nframes, nloc, nnei, 4)
final_coef = np.squeeze(

Check warning on line 268 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L268

Added line #L268 was not covered by tests
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
return final_coef

Check warning on line 274 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L273-L274

Added lines #L273 - L274 were not covered by tests

@staticmethod
def _calcualte_ener(coef: np.array, uu: np.array) -> np.array:

Check warning on line 277 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L276-L277

Added lines #L276 - L277 were not covered by tests
"""Calculate energy using spline coeeficients.

Parameters
----------
coef : np.array
The spline coefficients. (nframes, nloc, nnei, 4)
uu : np.array
The atom displancemnt used in interpolation and extrapolation (nframes, nloc, nnei)

Returns
-------
np.array
The atomic energy for all local atoms for all frames. (nframes, nloc, nnei)
"""
a3, a2, a1, a0 = coef[..., 0], coef[..., 1], coef[..., 2], coef[..., 3]
etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0 # this energy has the extrapolated value when rcut > rmax
return ener

Check warning on line 295 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L292-L295

Added lines #L292 - L295 were not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@
nn,
)

from deepmd.dpmodel import (
from deepmd.model_format import (

Check warning on line 14 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L14

Added line #L14 was not covered by tests
FittingOutputDef,
OutputVariableDef,
)
from deepmd.utils.pair_tab import (
PairTab,
)

from .base_atomic_model import (
BaseAtomicModel,
from .atomic_model import (

Check warning on line 22 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L22

Added line #L22 was not covered by tests
AtomicModel,
)


class PairTabModel(nn.Module, BaseAtomicModel):
class PairTabModel(nn.Module, AtomicModel):

Check warning on line 27 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L27

Added line #L27 was not covered by tests
"""Pairwise tabulation energy model.

This model can be used to tabulate the pairwise energy between atoms for either
Expand Down Expand Up @@ -54,13 +54,19 @@
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

self.tab = PairTab(self.tab_file, rcut=rcut)
self.ntypes = self.tab.ntypes

tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)
# handle deserialization with no input file
if self.tab_file is not None:
(

Check warning on line 61 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L60-L61

Added lines #L60 - L61 were not covered by tests
tab_info,
tab_data,
) = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)

Check warning on line 66 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L65-L66

Added lines #L65 - L66 were not covered by tests
else:
self.tab_info = None
self.tab_data = None

Check warning on line 69 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L68-L69

Added lines #L68 - L69 were not covered by tests

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class
Expand All @@ -72,7 +78,7 @@
else:
raise TypeError("sel must be int or list[int]")

def fitting_output_def(self) -> FittingOutputDef:
def get_fitting_output_def(self) -> FittingOutputDef:

Check warning on line 81 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L81

Added line #L81 was not covered by tests
return FittingOutputDef(
[
OutputVariableDef(
Expand All @@ -92,12 +98,18 @@
return False

def serialize(self) -> dict:
# place holder, implemantated in future PR
raise NotImplementedError

def deserialize(cls):
# place holder, implemantated in future PR
raise NotImplementedError
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}

Check warning on line 101 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L101

Added line #L101 was not covered by tests

@classmethod
def deserialize(cls, data) -> "PairTabModel":
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
tab_model.tab = tab
tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info)
tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data)
return tab_model

Check warning on line 112 in deepmd/pt/model/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab_model.py#L103-L112

Added lines #L103 - L112 were not covered by tests

def forward_atomic(
self,
Expand Down
Loading
Loading