Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pairwise tabulation as an independent model #3101

Merged
merged 5 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp
- [Deep potential long-range](doc/model/dplr.md)
- [Deep Potential - Range Correction (DPRc)](doc/model/dprc.md)
- [Linear model](doc/model/linear.md)
- [Interpolation with a pairwise potential](doc/model/pairtab.md)
- [Interpolation or combination with a pairwise potential](doc/model/pairtab.md)
- [Training](doc/train/index.md)
- [Training a model](doc/train/training.md)
- [Advanced options](doc/train/training-advanced.md)
Expand Down
5 changes: 5 additions & 0 deletions deepmd/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@
from deepmd.model.multi import (
MultiModel,
)
from deepmd.model.pairtab import (
PairTabModel,
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
from deepmd.model.pairwise_dprc import (
PairwiseDPRc,
)
Expand All @@ -112,6 +115,8 @@
return FrozenModel
elif model_type == "linear_ener":
return LinearEnergyModel
elif model_type == "pairtab":
return PairTabModel
else:
raise ValueError(f"unknown model type: {model_type}")

Expand Down
288 changes: 288 additions & 0 deletions deepmd/model/pairtab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
Enum,
)
from typing import (
List,
Optional,
Union,
)

import numpy as np

from deepmd.env import (
GLOBAL_TF_FLOAT_PRECISION,
MODEL_VERSION,
global_cvt_2_ener_float,
op_module,
tf,
)
from deepmd.fit.fitting import (
Fitting,
)
from deepmd.loss.loss import (
Loss,
)
from deepmd.model.model import (
Model,
)
Comment on lines +26 to +28

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.model.model
begins an import cycle.
from deepmd.utils.pair_tab import (
PairTab,
)


class PairTabModel(Model):
    """Pairwise tabulation energy model.

    This model can be used to tabulate the pairwise energy between atoms for either
    short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
    be used alone, but rather as one submodel of a linear (sum) model, such as
    DP+D3.

    Do not put the model on the first model of a linear model, since the linear
    model fetches the type map from the first model.

    At this moment, the model does not smooth the energy at the cutoff radius, so
    one needs to make sure the energy has been smoothed to zero.

    Parameters
    ----------
    tab_file : str
        The path to the tabulation file.
    rcut : float
        The cutoff radius
    sel : int or list[int]
        The maximum number of atoms in the cut-off radius. A list is collapsed
        to its sum, since only the total number of neighbors is used.
    **kwargs
        Extra keyword arguments are accepted and ignored by this constructor.

    Raises
    ------
    TypeError
        If `sel` is neither an int nor a list.
    """

    model_type = "ener"

    def __init__(
        self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
    ):
        super().__init__()
        self.tab_file = tab_file
        self.tab = PairTab(self.tab_file)
        # The number of atom types is dictated by the table, not by the caller.
        self.ntypes = self.tab.ntypes
        self.rcut = rcut
        if isinstance(sel, int):
            self.sel = sel
        elif isinstance(sel, list):
            # Per-type selections are not distinguished; keep only the total.
            self.sel = sum(sel)
        else:
            raise TypeError("sel must be int or list[int]")

    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        frz_model: Optional[str] = None,
        ckpt_meta: Optional[str] = None,
        suffix: str = "",
        reuse: Optional[Union[bool, Enum]] = None,
    ):
        """Build the model.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinates of atoms
        atype_ : tf.Tensor
            The atom types of atoms
        natoms : tf.Tensor
            The number of atoms
        box : tf.Tensor
            The box vectors
        mesh : tf.Tensor
            The mesh vectors
        input_dict : dict
            The input dict (not used by this model)
        frz_model : str, optional
            The path to the frozen model (not used by this model)
        ckpt_meta : str, optional
            The path prefix of the checkpoint and meta files (not used by this
            model)
        suffix : str, optional
            The suffix of the scope
        reuse : bool or tf.AUTO_REUSE, optional
            Whether to reuse the variables

        Returns
        -------
        dict
            The output dict with the keys ``energy``, ``force``, ``virial``,
            ``atom_ener``, ``atom_virial``, ``coord`` and ``atype``.
        """
        tab_info, tab_data = self.tab.get()
        with tf.variable_scope("model_attr" + suffix, reuse=reuse):
            # The table is stored as non-trainable variables initialized from
            # the values read from the tabulation file.
            self.tab_info = tf.get_variable(
                "t_tab_info",
                tab_info.shape,
                dtype=tf.float64,
                trainable=False,
                initializer=tf.constant_initializer(tab_info, dtype=tf.float64),
            )
            self.tab_data = tf.get_variable(
                "t_tab_data",
                tab_data.shape,
                dtype=tf.float64,
                trainable=False,
                initializer=tf.constant_initializer(tab_data, dtype=tf.float64),
            )
            # The constants in this and the following scopes are created only
            # for the named nodes they add to the graph; nothing in this method
            # reads them back, so they are not bound to local names.
            # NOTE(review): presumably looked up by node name by consumers of
            # the built graph - confirm before renaming or removing any.
            tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string)
            tf.constant(self.model_type, name="model_type", dtype=tf.string)
            tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)

        with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
            tf.constant(0, name="dfparam", dtype=tf.int32)
            tf.constant(0, name="daparam", dtype=tf.int32)
        with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
            tf.constant(self.ntypes, name="ntypes", dtype=tf.int32)
            tf.constant(self.rcut, name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION)
        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])
        box = tf.reshape(box, [-1, 9])
        # perhaps we need a OP that only outputs rij and nlist
        # Only the third and fourth outputs (rij, nlist) are consumed here.
        # NOTE(review): the zeros/ones arrays are presumably the mean/stddev
        # inputs of the op, chosen so that no normalisation is applied -
        # confirm against the op definition.
        (
            _,
            _,
            rij,
            nlist,
            _,
            _,
        ) = op_module.prod_env_mat_a_mix(
            coord,
            atype,
            natoms,
            box,
            mesh,
            np.zeros([self.ntypes, self.sel * 4]),
            np.ones([self.ntypes, self.sel * 4]),
            rcut_a=-1,
            rcut_r=self.rcut,
            rcut_r_smth=self.rcut,
            sel_a=[self.sel],
            sel_r=[0],
        )
        # A scale of one for every atom of every frame.
        scale = tf.ones([tf.shape(coord)[0], natoms[0]], dtype=tf.float64)
        tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
            self.tab_info,
            self.tab_data,
            atype,
            rij,
            nlist,
            natoms,
            scale,
            sel_a=[self.sel],
            sel_r=[0],
        )
        energy_raw = tf.reshape(
            tab_atom_ener, [-1, natoms[0]], name="o_atom_energy" + suffix
        )
        # Total energy of a frame is the sum of its per-atom energies.
        energy = tf.reduce_sum(
            global_cvt_2_ener_float(energy_raw), axis=1, name="o_energy" + suffix
        )
        force = tf.reshape(tab_force, [-1, 3 * natoms[1]], name="o_force" + suffix)
        # Total virial of a frame is the sum of its per-atom virials.
        virial = tf.reshape(
            tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1),
            [-1, 9],
            name="o_virial" + suffix,
        )
        atom_virial = tf.reshape(
            tab_atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix
        )
        model_dict = {}
        model_dict["energy"] = energy
        model_dict["force"] = force
        model_dict["virial"] = virial
        model_dict["atom_ener"] = energy_raw
        model_dict["atom_virial"] = atom_virial
        model_dict["coord"] = coord
        model_dict["atype"] = atype

        return model_dict

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        model_type: str = "original_model",
        suffix: str = "",
    ) -> None:
        """Init the variables with the given frozen model.

        This is a no-op for this model: the table can be initialized from the
        tabulation file, so nothing is read from the frozen model.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        model_type : str
            the type of the model
        suffix : str
            suffix to name scope
        """
        # skip. table can be initialized from the file

    def get_fitting(self) -> Union[Fitting, dict]:
        """Get the fitting(s).

        Returns
        -------
        dict
            Always an empty dict.
        """
        # nothing needs to do
        return {}

    def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]:
        """Get the loss function(s).

        Parameters
        ----------
        loss : dict
            The loss parameters (ignored)
        lr
            The learning rate (ignored)

        Returns
        -------
        None
            Always None.
        """
        # nothing needs to do
        return None

    def get_rcut(self) -> float:
        """Get cutoff radius of the model."""
        return self.rcut

    def get_ntypes(self) -> int:
        """Get the number of types."""
        return self.ntypes

    def data_stat(self, data: dict):
        """Data statistics.

        This is a no-op for this model.

        Parameters
        ----------
        data : dict
            The data (ignored)
        """
        # nothing needs to do

    def enable_compression(self, suffix: str = "") -> None:
        """Enable compression.

        This is a no-op for this model.

        Parameters
        ----------
        suffix : str
            suffix to name scope
        """
        # nothing needs to do

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
        """Update the selection and perform neighbor statistics.

        Notes
        -----
        Do not modify the input data without copying it.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            The updated local data
        """
        # Imported here rather than at module level: deepmd.entrypoints.train
        # is part of an import cycle with this module.
        from deepmd.entrypoints.train import (
            update_one_sel,
        )

        # Work on a shallow copy so the caller's dict is not rebound in place.
        local_jdata_cpy = local_jdata.copy()
        # NOTE(review): the meaning of the third positional argument (True) is
        # defined by update_one_sel and is not visible here - confirm.
        return update_one_sel(global_jdata, local_jdata_cpy, True)

Check warning on line 288 in deepmd/model/pairtab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model/pairtab.py#L287-L288

Added lines #L287 - L288 were not covered by tests
21 changes: 21 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,7 @@ def model_args(exclude_hybrid=False):
standard_model_args(),
multi_model_args(),
frozen_model_args(),
pairtab_model_args(),
*hybrid_models,
],
optional=True,
Expand Down Expand Up @@ -1013,6 +1014,26 @@ def frozen_model_args() -> Argument:
return ca


def pairtab_model_args() -> Argument:
    """Build the argument schema of the ``pairtab`` model type.

    Returns
    -------
    Argument
        A dict-typed argument named ``pairtab`` with three required
        sub-arguments: ``tab_file`` (str), ``rcut`` (float) and ``sel``
        (int, list of int, or str).
    """
    doc_tab_file = "Path to the tabulation file."
    doc_rcut = "The cut-off radius."
    # Adjacent string literals are used instead of backslash continuations
    # inside one literal, so the text does not depend on source indentation.
    # Every bullet is separated by a blank line so each renders as its own
    # list item.
    doc_sel = (
        "This parameter set the number of selected neighbors. Note that this parameter is a little different from that in other descriptors. Instead of separating each type of atoms, only the summation matters. And this number is highly related with the efficiency, thus one should not make it too large. Usually 200 or less is enough, far away from the GPU limitation 4096. It can be:\n\n"
        "- `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n"
        "- `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. Only the summation of `sel[i]` matters, and it is recommended to be less than 200.\n\n"
        '- `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wraped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".'
    )
    ca = Argument(
        "pairtab",
        dict,
        [
            Argument("tab_file", str, optional=False, doc=doc_tab_file),
            Argument("rcut", float, optional=False, doc=doc_rcut),
            Argument("sel", [int, List[int], str], optional=False, doc=doc_sel),
        ],
        doc="Pairwise tabulation energy model.",
    )
    return ca


def linear_ener_model_args() -> Argument:
doc_weights = (
"If the type is list of float, a list of weights for each model. "
Expand Down
2 changes: 1 addition & 1 deletion doc/model/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
- [Deep potential long-range](dplr.md)
- [Deep Potential - Range Correction (DPRc)](dprc.md)
- [Linear model](linear.md)
- [Interpolation with a pairwise potential](pairtab.md)
- [Interpolation or combination with a pairwise potential](pairtab.md)
Loading
Loading