Skip to content

Commit

Permalink
feat(tf): pass rcut to PairTab (#3794)
Browse files Browse the repository at this point in the history
Fix #1895. Fix #1973.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced model configurations with an optional cutoff radius (`rcut`)
for tabulated potentials.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored May 23, 2024
1 parent dd97895 commit ac892ce
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
13 changes: 13 additions & 0 deletions deepmd/tf/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from deepmd.tf.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.tf.utils.pair_tab import (
PairTab,
)
from deepmd.tf.utils.spin import (
Spin,
)
Expand Down Expand Up @@ -106,6 +109,16 @@ def __init__(
self.numb_fparam = self.fitting.get_numb_fparam()
self.numb_aparam = self.fitting.get_numb_aparam()

self.srtab_name = use_srtab
if self.srtab_name is not None:
self.srtab = PairTab(self.srtab_name, rcut=self.get_rcut())
self.smin_alpha = smin_alpha
self.sw_rmin = sw_rmin
self.sw_rmax = sw_rmax
self.srtab_add_bias = srtab_add_bias
else:
self.srtab = None

def get_rcut(self):
return self.rcut

Expand Down
17 changes: 0 additions & 17 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@
from deepmd.tf.utils.graph import (
load_graph_def,
)
from deepmd.tf.utils.pair_tab import (
PairTab,
)
from deepmd.tf.utils.spin import (
Spin,
)
Expand Down Expand Up @@ -116,11 +113,6 @@ def __init__(
data_stat_nbatch: int = 10,
data_bias_nsample: int = 10,
data_stat_protect: float = 1e-2,
use_srtab: Optional[str] = None,
smin_alpha: Optional[float] = None,
sw_rmin: Optional[float] = None,
sw_rmax: Optional[float] = None,
srtab_add_bias: bool = True,
spin: Optional[Spin] = None,
compress: Optional[dict] = None,
**kwargs,
Expand All @@ -142,15 +134,6 @@ def __init__(
self.data_stat_nbatch = data_stat_nbatch
self.data_bias_nsample = data_bias_nsample
self.data_stat_protect = data_stat_protect
self.srtab_name = use_srtab
if self.srtab_name is not None:
self.srtab = PairTab(self.srtab_name)
self.smin_alpha = smin_alpha
self.sw_rmin = sw_rmin
self.sw_rmax = sw_rmax
self.srtab_add_bias = srtab_add_bias
else:
self.srtab = None

def get_type_map(self) -> list:
"""Get the type map."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/pairtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
):
super().__init__()
self.tab_file = tab_file
self.tab = PairTab(self.tab_file)
self.tab = PairTab(self.tab_file, rcut=rcut)
self.ntypes = self.tab.ntypes
self.rcut = rcut
if isinstance(sel, int):
Expand Down
4 changes: 4 additions & 0 deletions deepmd/utils/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class PairTab:
The second to the last columes are energies for pairs of certain types.
For example we have two atom types, 0 and 1.
The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly.
rcut : float, optional
cutoff raduis for the tabulated potential
"""

def __init__(self, filename: str, rcut: Optional[float] = None) -> None:
Expand All @@ -49,6 +51,8 @@ def reinit(self, filename: str, rcut: Optional[float] = None) -> None:
The second to the last columes are energies for pairs of certain types.
For example we have two atom types, 0 and 1.
The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly.
rcut : float, optional
cutoff raduis for the tabulated potential
"""
if filename is None:
self.tab_info, self.tab_data = None, None
Expand Down

0 comments on commit ac892ce

Please sign in to comment.