Skip to content

Commit

Permalink
New substru dpa2 (#33)
Browse files Browse the repository at this point in the history
* update first three subs

* add three body for repinit

* resolve conversation

* breaking change the default values

* fix uts

* fix uts
  • Loading branch information
iProzd authored Sep 6, 2024
1 parent 9089d2a commit 048d75f
Show file tree
Hide file tree
Showing 11 changed files with 724 additions and 58 deletions.
175 changes: 169 additions & 6 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
DescrptBlockRepformers,
RepformerLayer,
)
from .se_t_tebd import (
DescrptBlockSeTTebd,
)


class RepinitArgs:
Expand All @@ -75,6 +78,11 @@ def __init__(
activation_function="tanh",
resnet_dt: bool = False,
type_one_side: bool = False,
use_three_body: bool = False,
three_body_neuron: List[int] = [2, 4, 8],
three_body_sel: int = 40,
three_body_rcut: float = 4.0,
three_body_rcut_smth: float = 0.5,
):
r"""The constructor for the RepinitArgs class which defines the parameters of the repinit block in DPA2 descriptor.
Expand Down Expand Up @@ -104,6 +112,19 @@ def __init__(
Whether to use a "Timestep" in the skip connection.
type_one_side : bool, optional
Whether to use one-side type embedding.
use_three_body : bool, optional
Whether to concatenate three-body representation in the output descriptor.
three_body_neuron : list, optional
Number of neurons in each hidden layers of the three-body embedding net.
When two layers are of the same size or one layer is twice as large as the previous layer,
a skip connection is built.
three_body_sel : int, optional
Maximally possible number of selected neighbors in the three-body representation.
three_body_rcut : float, optional
The cut-off radius in the three-body representation.
three_body_rcut_smth : float, optional
Where to start smoothing in the three-body representation.
For example the 1/r term is smoothed from three_body_rcut to three_body_rcut_smth.
"""
self.rcut = rcut
self.rcut_smth = rcut_smth
Expand All @@ -116,6 +137,11 @@ def __init__(
self.activation_function = activation_function
self.resnet_dt = resnet_dt
self.type_one_side = type_one_side
self.use_three_body = use_three_body
self.three_body_neuron = three_body_neuron
self.three_body_sel = three_body_sel
self.three_body_rcut = three_body_rcut
self.three_body_rcut_smth = three_body_rcut_smth

def __getitem__(self, key):
if hasattr(self, key):
Expand All @@ -136,6 +162,11 @@ def serialize(self) -> dict:
"activation_function": self.activation_function,
"resnet_dt": self.resnet_dt,
"type_one_side": self.type_one_side,
"use_three_body": self.use_three_body,
"three_body_neuron": self.three_body_neuron,
"three_body_sel": self.three_body_sel,
"three_body_rcut": self.three_body_rcut,
"three_body_rcut_smth": self.three_body_rcut_smth,
}

@classmethod
Expand Down Expand Up @@ -172,6 +203,9 @@ def __init__(
update_residual_init: str = "norm",
set_davg_zero: bool = True,
trainable_ln: bool = True,
use_sqrt_nnei: bool = True,
g1_out_conv: bool = True,
g1_out_mlp: bool = True,
ln_eps: Optional[float] = 1e-5,
):
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -236,6 +270,12 @@ def __init__(
Set the normalization average to zero.
trainable_ln : bool, optional
Whether to use trainable shift and scale weights in layer normalization.
use_sqrt_nnei : bool, optional
Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly.
g1_out_conv : bool, optional
Whether to put the convolutional update of g1 separately outside the concatenated MLP update.
g1_out_mlp : bool, optional
Whether to put the self MLP update of g1 separately outside the concatenated MLP update.
ln_eps : float, optional
The epsilon value for layer normalization.
"""
Expand Down Expand Up @@ -265,6 +305,9 @@ def __init__(
self.update_residual_init = update_residual_init
self.set_davg_zero = set_davg_zero
self.trainable_ln = trainable_ln
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -304,6 +347,9 @@ def serialize(self) -> dict:
"update_residual_init": self.update_residual_init,
"set_davg_zero": self.set_davg_zero,
"trainable_ln": self.trainable_ln,
"use_sqrt_nnei": self.use_sqrt_nnei,
"g1_out_conv": self.g1_out_conv,
"g1_out_mlp": self.g1_out_mlp,
"ln_eps": self.ln_eps,
}

Expand Down Expand Up @@ -416,6 +462,27 @@ def init_subclass_params(sub_data, sub_class):
type_one_side=self.repinit_args.type_one_side,
seed=child_seed(seed, 0),
)
self.use_three_body = self.repinit_args.use_three_body
if self.use_three_body:
self.repinit_three_body = DescrptBlockSeTTebd(
self.repinit_args.three_body_rcut,
self.repinit_args.three_body_rcut_smth,
self.repinit_args.three_body_sel,
ntypes,
neuron=self.repinit_args.three_body_neuron,
tebd_dim=self.repinit_args.tebd_dim,
tebd_input_mode=self.repinit_args.tebd_input_mode,
set_davg_zero=self.repinit_args.set_davg_zero,
exclude_types=exclude_types,
env_protection=env_protection,
activation_function=self.repinit_args.activation_function,
precision=precision,
resnet_dt=self.repinit_args.resnet_dt,
smooth=smooth,
seed=child_seed(seed, 5),
)
else:
self.repinit_three_body = None
self.repformers = DescrptBlockRepformers(
self.repformer_args.rcut,
self.repformer_args.rcut_smth,
Expand Down Expand Up @@ -448,9 +515,27 @@ def init_subclass_params(sub_data, sub_class):
env_protection=env_protection,
precision=precision,
trainable_ln=self.repformer_args.trainable_ln,
use_sqrt_nnei=self.repformer_args.use_sqrt_nnei,
g1_out_conv=self.repformer_args.g1_out_conv,
g1_out_mlp=self.repformer_args.g1_out_mlp,
ln_eps=self.repformer_args.ln_eps,
seed=child_seed(seed, 1),
)
self.rcsl_list = [
(self.repformers.get_rcut(), self.repformers.get_nsel()),
(self.repinit.get_rcut(), self.repinit.get_nsel()),
]
if self.use_three_body:
self.rcsl_list.append(
(self.repinit_three_body.get_rcut(), self.repinit_three_body.get_nsel())
)
self.rcsl_list.sort()
for ii in range(1, len(self.rcsl_list)):
assert (
self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1]
), "rcut and sel are not in the same order"
self.rcut_list = [ii[0] for ii in self.rcsl_list]
self.nsel_list = [ii[1] for ii in self.rcsl_list]
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand All @@ -473,11 +558,16 @@ def init_subclass_params(sub_data, sub_class):
self.trainable = trainable
self.add_tebd_to_repinit_out = add_tebd_to_repinit_out

if self.repinit.dim_out == self.repformers.dim_in:
self.repinit_out_dim = self.repinit.dim_out
if self.repinit_args.use_three_body:
assert self.repinit_three_body is not None
self.repinit_out_dim += self.repinit_three_body.dim_out

if self.repinit_out_dim == self.repformers.dim_in:
self.g1_shape_tranform = Identity()
else:
self.g1_shape_tranform = NativeLayer(
self.repinit.dim_out,
self.repinit_out_dim,
self.repformers.dim_in,
bias=False,
precision=precision,
Expand Down Expand Up @@ -585,6 +675,7 @@ def change_type_map(
self.ntypes = len(type_map)
repinit = self.repinit
repformers = self.repformers
repinit_three_body = self.repinit_three_body
if has_new_type:
# the avg and std of new types need to be updated
extend_descrpt_stat(
Expand All @@ -601,6 +692,14 @@ def change_type_map(
if model_with_new_type_stat is not None
else None,
)
if self.use_three_body:
extend_descrpt_stat(
repinit_three_body,
type_map,
des_with_stat=model_with_new_type_stat.repinit_three_body
if model_with_new_type_stat is not None
else None,
)
repinit.ntypes = self.ntypes
repformers.ntypes = self.ntypes
repinit.reinit_exclude(self.exclude_types)
Expand All @@ -609,6 +708,11 @@ def change_type_map(
repinit["dstd"] = repinit["dstd"][remap_index]
repformers["davg"] = repformers["davg"][remap_index]
repformers["dstd"] = repformers["dstd"][remap_index]
if self.use_three_body:
repinit_three_body.ntypes = self.ntypes
repinit_three_body.reinit_exclude(self.exclude_types)
repinit_three_body["davg"] = repinit_three_body["davg"][remap_index]
repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index]

@property
def dim_out(self):
Expand Down Expand Up @@ -677,14 +781,15 @@ def call(
The smooth switch function. shape: nf x nloc x nnei
"""
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
# nlists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
nlist,
[self.repformers.get_rcut(), self.repinit.get_rcut()],
[self.repformers.get_nsel(), self.repinit.get_nsel()],
self.rcut_list,
self.nsel_list,
)
# repinit
g1_ext = self.type_embedding.call()[atype_ext]
Expand All @@ -698,6 +803,21 @@ def call(
g1_ext,
mapping,
)
if use_three_body:
assert self.repinit_three_body is not None
g1_three_body, __, __, __, __ = self.repinit_three_body(
nlist_dict[
get_multiple_nlist_key(
self.repinit_three_body.get_rcut(),
self.repinit_three_body.get_nsel(),
)
],
coord_ext,
atype_ext,
g1_ext,
mapping,
)
g1 = np.concatenate([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
Expand Down Expand Up @@ -726,10 +846,11 @@ def call(
def serialize(self) -> dict:
repinit = self.repinit
repformers = self.repformers
repinit_three_body = self.repinit_three_body
data = {
"@class": "Descriptor",
"type": "dpa2",
"@version": 2,
"@version": 3,
"ntypes": self.ntypes,
"repinit_args": self.repinit_args.serialize(),
"repformer_args": self.repformer_args.serialize(),
Expand Down Expand Up @@ -779,16 +900,43 @@ def serialize(self) -> dict:
"repformers_variable": repformers_variable,
}
)
if self.use_three_body:
repinit_three_body_variable = {
"embeddings": repinit_three_body.embeddings.serialize(),
"env_mat": EnvMat(
repinit_three_body.rcut, repinit_three_body.rcut_smth
).serialize(),
"@variables": {
"davg": repinit_three_body["davg"],
"dstd": repinit_three_body["dstd"],
},
}
if repinit_three_body.tebd_input_mode in ["strip"]:
repinit_three_body_variable.update(
{
"embeddings_strip": repinit_three_body.embeddings_strip.serialize()
}
)
data.update(
{
"repinit_three_body_variable": repinit_three_body_variable,
}
)
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA2":
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
check_version_compatibility(data.pop("@version"), 3, 3)
data.pop("@class")
data.pop("type")
repinit_variable = data.pop("repinit_variable").copy()
repformers_variable = data.pop("repformers_variable").copy()
repinit_three_body_variable = (
data.pop("repinit_three_body_variable").copy()
if "repinit_three_body_variable" in data
else None
)
type_embedding = data.pop("type_embedding")
g1_shape_tranform = data.pop("g1_shape_tranform")
tebd_transform = data.pop("tebd_transform", None)
Expand Down Expand Up @@ -820,6 +968,21 @@ def deserialize(cls, data: dict) -> "DescrptDPA2":
obj.repinit["davg"] = statistic_repinit["davg"]
obj.repinit["dstd"] = statistic_repinit["dstd"]

if data["repinit"].use_three_body:
# deserialize repinit_three_body
statistic_repinit_three_body = repinit_three_body_variable.pop("@variables")
env_mat = repinit_three_body_variable.pop("env_mat")
tebd_input_mode = data["repinit"].tebd_input_mode
obj.repinit_three_body.embeddings = NetworkCollection.deserialize(
repinit_three_body_variable.pop("embeddings")
)
if tebd_input_mode in ["strip"]:
obj.repinit_three_body.embeddings_strip = NetworkCollection.deserialize(
repinit_three_body_variable.pop("embeddings_strip")
)
obj.repinit_three_body["davg"] = statistic_repinit_three_body["davg"]
obj.repinit_three_body["dstd"] = statistic_repinit_three_body["dstd"]

# deserialize repformers
statistic_repformers = repformers_variable.pop("@variables")
env_mat = repformers_variable.pop("env_mat")
Expand Down
Loading

0 comments on commit 048d75f

Please sign in to comment.