Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking(pt/dp): tune new sub-structures for DPA2 #4089

Merged
merged 9 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 169 additions & 6 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
DescrptBlockRepformers,
RepformerLayer,
)
from .se_t_tebd import (
DescrptBlockSeTTebd,
)


class RepinitArgs:
Expand All @@ -75,6 +78,11 @@ def __init__(
activation_function="tanh",
resnet_dt: bool = False,
type_one_side: bool = False,
use_three_body: bool = False,
three_body_neuron: List[int] = [2, 4, 8],
three_body_sel: int = 40,
three_body_rcut: float = 4.0,
three_body_rcut_smth: float = 0.5,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
):
r"""The constructor for the RepinitArgs class which defines the parameters of the repinit block in DPA2 descriptor.

Expand Down Expand Up @@ -104,6 +112,19 @@ def __init__(
Whether to use a "Timestep" in the skip connection.
type_one_side : bool, optional
Whether to use one-side type embedding.
use_three_body : bool, optional
Whether to concatenate the three-body representation into the output descriptor.
three_body_neuron : list, optional
Number of neurons in each hidden layer of the three-body embedding net.
When two layers are of the same size or one layer is twice as large as the previous layer,
a skip connection is built.
three_body_sel : int, optional
Maximally possible number of selected neighbors in the three-body representation.
three_body_rcut : float, optional
The cut-off radius in the three-body representation.
three_body_rcut_smth : float, optional
Where to start smoothing in the three-body representation.
For example, the 1/r term is smoothed from three_body_rcut_smth to three_body_rcut.
"""
self.rcut = rcut
self.rcut_smth = rcut_smth
Expand All @@ -116,6 +137,11 @@ def __init__(
self.activation_function = activation_function
self.resnet_dt = resnet_dt
self.type_one_side = type_one_side
self.use_three_body = use_three_body
self.three_body_neuron = three_body_neuron
self.three_body_sel = three_body_sel
self.three_body_rcut = three_body_rcut
self.three_body_rcut_smth = three_body_rcut_smth

def __getitem__(self, key):
if hasattr(self, key):
Expand All @@ -136,6 +162,11 @@ def serialize(self) -> dict:
"activation_function": self.activation_function,
"resnet_dt": self.resnet_dt,
"type_one_side": self.type_one_side,
"use_three_body": self.use_three_body,
"three_body_neuron": self.three_body_neuron,
"three_body_sel": self.three_body_sel,
"three_body_rcut": self.three_body_rcut,
"three_body_rcut_smth": self.three_body_rcut_smth,
}

@classmethod
Expand Down Expand Up @@ -172,6 +203,9 @@ def __init__(
update_residual_init: str = "norm",
set_davg_zero: bool = True,
trainable_ln: bool = True,
use_sqrt_nnei: bool = True,
g1_out_conv: bool = True,
g1_out_mlp: bool = True,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
ln_eps: Optional[float] = 1e-5,
):
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -236,6 +270,12 @@ def __init__(
Set the normalization average to zero.
trainable_ln : bool, optional
Whether to use trainable shift and scale weights in layer normalization.
use_sqrt_nnei : bool, optional
Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly.
g1_out_conv : bool, optional
Whether to put the convolutional update of g1 separately outside the concatenated MLP update.
g1_out_mlp : bool, optional
Whether to put the self MLP update of g1 separately outside the concatenated MLP update.
ln_eps : float, optional
The epsilon value for layer normalization.
"""
Expand Down Expand Up @@ -265,6 +305,9 @@ def __init__(
self.update_residual_init = update_residual_init
self.set_davg_zero = set_davg_zero
self.trainable_ln = trainable_ln
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
# to keep consistent with the default value in this backend
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -304,6 +347,9 @@ def serialize(self) -> dict:
"update_residual_init": self.update_residual_init,
"set_davg_zero": self.set_davg_zero,
"trainable_ln": self.trainable_ln,
"use_sqrt_nnei": self.use_sqrt_nnei,
"g1_out_conv": self.g1_out_conv,
"g1_out_mlp": self.g1_out_mlp,
"ln_eps": self.ln_eps,
}

Expand Down Expand Up @@ -416,6 +462,27 @@ def init_subclass_params(sub_data, sub_class):
type_one_side=self.repinit_args.type_one_side,
seed=child_seed(seed, 0),
)
self.use_three_body = self.repinit_args.use_three_body
if self.use_three_body:
self.repinit_three_body = DescrptBlockSeTTebd(
self.repinit_args.three_body_rcut,
self.repinit_args.three_body_rcut_smth,
self.repinit_args.three_body_sel,
ntypes,
neuron=self.repinit_args.three_body_neuron,
tebd_dim=self.repinit_args.tebd_dim,
tebd_input_mode=self.repinit_args.tebd_input_mode,
set_davg_zero=self.repinit_args.set_davg_zero,
exclude_types=exclude_types,
env_protection=env_protection,
activation_function=self.repinit_args.activation_function,
precision=precision,
resnet_dt=self.repinit_args.resnet_dt,
smooth=smooth,
seed=child_seed(seed, 5),
)
else:
self.repinit_three_body = None
self.repformers = DescrptBlockRepformers(
self.repformer_args.rcut,
self.repformer_args.rcut_smth,
Expand Down Expand Up @@ -448,9 +515,27 @@ def init_subclass_params(sub_data, sub_class):
env_protection=env_protection,
precision=precision,
trainable_ln=self.repformer_args.trainable_ln,
use_sqrt_nnei=self.repformer_args.use_sqrt_nnei,
g1_out_conv=self.repformer_args.g1_out_conv,
g1_out_mlp=self.repformer_args.g1_out_mlp,
ln_eps=self.repformer_args.ln_eps,
seed=child_seed(seed, 1),
)
self.rcsl_list = [
(self.repformers.get_rcut(), self.repformers.get_nsel()),
(self.repinit.get_rcut(), self.repinit.get_nsel()),
]
if self.use_three_body:
self.rcsl_list.append(
(self.repinit_three_body.get_rcut(), self.repinit_three_body.get_nsel())
)
self.rcsl_list.sort()
for ii in range(1, len(self.rcsl_list)):
assert (
self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1]
), "rcut and sel are not in the same order"
self.rcut_list = [ii[0] for ii in self.rcsl_list]
self.nsel_list = [ii[1] for ii in self.rcsl_list]
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand All @@ -473,11 +558,16 @@ def init_subclass_params(sub_data, sub_class):
self.trainable = trainable
self.add_tebd_to_repinit_out = add_tebd_to_repinit_out

if self.repinit.dim_out == self.repformers.dim_in:
self.repinit_out_dim = self.repinit.dim_out
if self.repinit_args.use_three_body:
assert self.repinit_three_body is not None
self.repinit_out_dim += self.repinit_three_body.dim_out

if self.repinit_out_dim == self.repformers.dim_in:
self.g1_shape_tranform = Identity()
else:
self.g1_shape_tranform = NativeLayer(
self.repinit.dim_out,
self.repinit_out_dim,
self.repformers.dim_in,
bias=False,
precision=precision,
Expand Down Expand Up @@ -585,6 +675,7 @@ def change_type_map(
self.ntypes = len(type_map)
repinit = self.repinit
repformers = self.repformers
repinit_three_body = self.repinit_three_body
if has_new_type:
# the avg and std of new types need to be updated
extend_descrpt_stat(
Expand All @@ -601,6 +692,14 @@ def change_type_map(
if model_with_new_type_stat is not None
else None,
)
if self.use_three_body:
extend_descrpt_stat(
repinit_three_body,
type_map,
des_with_stat=model_with_new_type_stat.repinit_three_body
if model_with_new_type_stat is not None
else None,
)
repinit.ntypes = self.ntypes
repformers.ntypes = self.ntypes
repinit.reinit_exclude(self.exclude_types)
Expand All @@ -609,6 +708,11 @@ def change_type_map(
repinit["dstd"] = repinit["dstd"][remap_index]
repformers["davg"] = repformers["davg"][remap_index]
repformers["dstd"] = repformers["dstd"][remap_index]
if self.use_three_body:
repinit_three_body.ntypes = self.ntypes
repinit_three_body.reinit_exclude(self.exclude_types)
repinit_three_body["davg"] = repinit_three_body["davg"][remap_index]
repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index]

@property
def dim_out(self):
Expand Down Expand Up @@ -677,14 +781,15 @@ def call(
The smooth switch function. shape: nf x nloc x nnei

"""
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
# nlists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
nlist,
[self.repformers.get_rcut(), self.repinit.get_rcut()],
[self.repformers.get_nsel(), self.repinit.get_nsel()],
self.rcut_list,
self.nsel_list,
)
# repinit
g1_ext = self.type_embedding.call()[atype_ext]
Expand All @@ -698,6 +803,21 @@ def call(
g1_ext,
mapping,
)
if use_three_body:
assert self.repinit_three_body is not None
g1_three_body, __, __, __, __ = self.repinit_three_body(
nlist_dict[
get_multiple_nlist_key(
self.repinit_three_body.get_rcut(),
self.repinit_three_body.get_nsel(),
)
],
coord_ext,
atype_ext,
g1_ext,
mapping,
)
g1 = np.concatenate([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
Expand Down Expand Up @@ -726,10 +846,11 @@ def call(
def serialize(self) -> dict:
repinit = self.repinit
repformers = self.repformers
repinit_three_body = self.repinit_three_body
data = {
"@class": "Descriptor",
"type": "dpa2",
"@version": 2,
"@version": 3,
"ntypes": self.ntypes,
"repinit_args": self.repinit_args.serialize(),
"repformer_args": self.repformer_args.serialize(),
Expand Down Expand Up @@ -779,16 +900,43 @@ def serialize(self) -> dict:
"repformers_variable": repformers_variable,
}
)
if self.use_three_body:
repinit_three_body_variable = {
"embeddings": repinit_three_body.embeddings.serialize(),
"env_mat": EnvMat(
repinit_three_body.rcut, repinit_three_body.rcut_smth
).serialize(),
"@variables": {
"davg": repinit_three_body["davg"],
"dstd": repinit_three_body["dstd"],
},
}
if repinit_three_body.tebd_input_mode in ["strip"]:
repinit_three_body_variable.update(
{
"embeddings_strip": repinit_three_body.embeddings_strip.serialize()
}
)
data.update(
{
"repinit_three_body_variable": repinit_three_body_variable,
}
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA2":
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
check_version_compatibility(data.pop("@version"), 3, 3)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
data.pop("@class")
data.pop("type")
repinit_variable = data.pop("repinit_variable").copy()
repformers_variable = data.pop("repformers_variable").copy()
repinit_three_body_variable = (
data.pop("repinit_three_body_variable").copy()
if "repinit_three_body_variable" in data
else None
)
type_embedding = data.pop("type_embedding")
g1_shape_tranform = data.pop("g1_shape_tranform")
tebd_transform = data.pop("tebd_transform", None)
Expand Down Expand Up @@ -820,6 +968,21 @@ def deserialize(cls, data: dict) -> "DescrptDPA2":
obj.repinit["davg"] = statistic_repinit["davg"]
obj.repinit["dstd"] = statistic_repinit["dstd"]

if data["repinit"].use_three_body:
# deserialize repinit_three_body
statistic_repinit_three_body = repinit_three_body_variable.pop("@variables")
env_mat = repinit_three_body_variable.pop("env_mat")
tebd_input_mode = data["repinit"].tebd_input_mode
obj.repinit_three_body.embeddings = NetworkCollection.deserialize(
repinit_three_body_variable.pop("embeddings")
)
if tebd_input_mode in ["strip"]:
obj.repinit_three_body.embeddings_strip = NetworkCollection.deserialize(
repinit_three_body_variable.pop("embeddings_strip")
)
obj.repinit_three_body["davg"] = statistic_repinit_three_body["davg"]
obj.repinit_three_body["dstd"] = statistic_repinit_three_body["dstd"]

# deserialize repformers
statistic_repformers = repformers_variable.pop("@variables")
env_mat = repformers_variable.pop("env_mat")
Expand Down
Loading
Loading