feat: consistent type embedding #3617

Merged 5 commits on Mar 31, 2024
124 changes: 124 additions & 0 deletions deepmd/dpmodel/utils/type_embed.py
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.utils.network import (
EmbeddingNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)


class TypeEmbedNet(NativeOP):
r"""Type embedding network.

Parameters
----------
ntypes : int
Number of atom types
neuron : list[int]
Number of neurons in each hidden layer of the embedding net
resnet_dt
Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
activation_function
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
trainable
If the weights of embedding net are trainable.
seed
Random seed for initializing the network parameters.
padding
Concatenate zero padding to the output, used as the default embedding of the empty type.
"""

def __init__(
self,
*,
ntypes: int,
neuron: List[int],
resnet_dt: bool = False,
activation_function: str = "tanh",
precision: str = "default",
trainable: bool = True,
seed: Optional[int] = None,
padding: bool = False,
) -> None:
self.ntypes = ntypes
self.neuron = neuron
self.seed = seed
self.resnet_dt = resnet_dt
self.precision = precision
self.activation_function = str(activation_function)
self.trainable = trainable
self.padding = padding
self.embedding_net = EmbeddingNet(
ntypes,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)

def call(self) -> np.ndarray:
"""Compute the type embedding network."""
embed = self.embedding_net(
np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
)
if self.padding:
embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
return embed

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
TypeEmbedNet
The deserialized model
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data_cls = data.pop("@class")
assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
type_embedding_net = cls(**data)
type_embedding_net.embedding_net = embedding_net
return type_embedding_net

def serialize(self) -> dict:
"""Serialize the model.

Returns
-------
dict
The serialized data
"""
return {
"@class": "TypeEmbedNet",
"@version": 1,
"ntypes": self.ntypes,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"precision": self.precision,
"activation_function": self.activation_function,
"trainable": self.trainable,
"padding": self.padding,
"embedding": self.embedding_net.serialize(),
}
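For reference, here is a minimal usage sketch of the new NumPy-backed net (not part of the diff; it assumes `deepmd` is importable and uses only the API shown above): build the net, evaluate the full embedding table, and round-trip it through `serialize`/`deserialize`.

```python
import numpy as np

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet

net = TypeEmbedNet(ntypes=2, neuron=[8], padding=True)
embed = net.call()  # identity matrix of the types pushed through the embedding net
# With padding=True one extra all-zero row is appended for the empty type.
assert embed.shape == (3, 8)
assert np.allclose(embed[-1], 0.0)

# The serialized dict fully reconstructs the network and its weights.
restored = TypeEmbedNet.deserialize(net.serialize())
assert np.allclose(restored.call(), embed)
```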
141 changes: 134 additions & 7 deletions deepmd/pt/model/network/network.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

Expand All @@ -8,9 +9,15 @@
import torch.nn as nn
import torch.nn.functional as F

from deepmd.pt.model.network.mlp import (
EmbeddingNet,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.version import (
check_version_compatibility,
)

try:
from typing import (
Expand Down Expand Up @@ -552,12 +559,12 @@ class TypeEmbedNet(nn.Module):
def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0):
"""Construct a type embedding net."""
super().__init__()
-        self.embedding = nn.Embedding(
-            type_nums + 1,
-            embed_dim,
-            padding_idx=type_nums,
-            dtype=env.GLOBAL_PT_FLOAT_PRECISION,
-            device=env.DEVICE,
+        self.embedding = TypeEmbedNetConsistent(
+            ntypes=type_nums,
+            neuron=[embed_dim],
+            padding=True,
+            activation_function="Linear",
+            precision="default",
        )
# nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev)

@@ -571,7 +578,7 @@ def forward(self, atype):
type_embedding:

"""
-        return self.embedding(atype)
+        return self.embedding(atype.device)[atype]

def share_params(self, base_class, shared_level, resume=False):
"""
@@ -590,6 +597,126 @@ def share_params(self, base_class, shared_level, resume=False):
raise NotImplementedError


class TypeEmbedNetConsistent(nn.Module):
r"""Type embedding network that is consistent with other backends.

Parameters
----------
ntypes : int
Number of atom types
neuron : list[int]
Number of neurons in each hidden layer of the embedding net
resnet_dt
Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
activation_function
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
trainable
If the weights of embedding net are trainable.
seed
Random seed for initializing the network parameters.
padding
Concatenate zero padding to the output, used as the default embedding of the empty type.
"""

def __init__(
self,
*,
ntypes: int,
neuron: List[int],
resnet_dt: bool = False,
activation_function: str = "tanh",
precision: str = "default",
trainable: bool = True,
seed: Optional[int] = None,
padding: bool = False,
):
"""Construct a type embedding net."""
super().__init__()
self.ntypes = ntypes
self.neuron = neuron
self.seed = seed
self.resnet_dt = resnet_dt
self.precision = precision
self.prec = env.PRECISION_DICT[self.precision]
self.activation_function = str(activation_function)
self.trainable = trainable
self.padding = padding
# no way to pass seed?
self.embedding_net = EmbeddingNet(
ntypes,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
for param in self.parameters():
param.requires_grad = trainable

def forward(self, device: torch.device):
"""Caulate type embedding network.

Returns
-------
type_embedding: torch.Tensor
Type embedding network.
"""
embed = self.embedding_net(
torch.eye(self.ntypes, dtype=self.prec, device=device)
)
if self.padding:
embed = torch.cat(
[embed, torch.zeros(1, embed.shape[1], dtype=self.prec, device=device)]
)
return embed

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
TypeEmbedNetConsistent
The deserialized model
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data_cls = data.pop("@class")
assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
type_embedding_net = cls(**data)
type_embedding_net.embedding_net = embedding_net
return type_embedding_net

def serialize(self) -> dict:
"""Serialize the model.

Returns
-------
dict
The serialized data
"""
return {
"@class": "TypeEmbedNet",
"@version": 1,
"ntypes": self.ntypes,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"precision": self.precision,
"activation_function": self.activation_function,
"trainable": self.trainable,
"padding": self.padding,
"embedding": self.embedding_net.serialize(),
}


@torch.jit.script
def gaussian(x, mean, std: float):
pi = 3.14159
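A sketch of how the two classes above compose (illustrative, not from the diff): `TypeEmbedNetConsistent.forward` takes a device and returns the full `(ntypes + 1, dim)` table when `padding=True`, and the wrapper's `self.embedding(atype.device)[atype]` then gathers one row per atom from that table.

```python
import torch

from deepmd.pt.model.network.network import TypeEmbedNetConsistent

net = TypeEmbedNetConsistent(ntypes=2, neuron=[8], padding=True)
atype = torch.tensor([0, 1, 1, 0])  # one type index per atom
table = net(atype.device)           # (ntypes + 1, 8) embedding table
per_atom = table[atype]             # gather rows: (natoms, 8)
assert per_atom.shape == (4, 8)
# The last row is the zero padding used for the empty type.
assert torch.allclose(table[-1], torch.zeros_like(table[-1]))
```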
11 changes: 7 additions & 4 deletions deepmd/tf/env.py
@@ -168,11 +168,14 @@ def dlopen_library(module: str, filename: str):
r"share_.+/idt|"
)[:-1]

+# subpatterns:
+# \1: weight name
+# \2: layer index
TYPE_EMBEDDING_PATTERN = str(
-    r"type_embed_net+/matrix_\d+|"
-    r"type_embed_net+/bias_\d+|"
-    r"type_embed_net+/idt_\d+|"
-)
+    r"type_embed_net/(matrix)_(\d+)|"
+    r"type_embed_net/(bias)_(\d+)|"
+    r"type_embed_net/(idt)_(\d+)|"
+)[:-1]

ATTENTION_LAYER_PATTERN = str(
r"attention_layer_\d+/c_query/matrix|"
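A standalone check (illustrative, not from the diff) of what the revised pattern matches: within each alternative, the first of its two capture groups is the weight name and the second is the layer index, which is what the new `# subpatterns` comment documents.

```python
import re

# Same pattern as above, written without the trailing "|" that [:-1] strips.
TYPE_EMBEDDING_PATTERN = (
    r"type_embed_net/(matrix)_(\d+)|"
    r"type_embed_net/(bias)_(\d+)|"
    r"type_embed_net/(idt)_(\d+)"
)

m = re.fullmatch(TYPE_EMBEDDING_PATTERN, "type_embed_net/matrix_1")
assert m is not None
assert m.group(1) == "matrix" and m.group(2) == "1"
# The old "type_embed_net+" spelling would also have accepted extra "t"s;
# the revised pattern does not:
assert re.fullmatch(TYPE_EMBEDDING_PATTERN, "type_embed_nett/matrix_1") is None
```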
2 changes: 2 additions & 0 deletions deepmd/tf/model/model.py
@@ -678,6 +678,7 @@ def __init__(
self.typeebd = type_embedding
elif type_embedding is not None:
self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
**type_embedding,
padding=self.descrpt.explicit_ntypes,
)
@@ -686,6 +687,7 @@
default_args_dict = {i.name: i.default for i in default_args}
default_args_dict["activation_function"] = None
self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
**default_args_dict,
padding=True,
)
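The tf-side changes in this file (and in multi.py and pairwise_dprc.py below) all follow one pattern: `ntypes` is now passed to `TypeEmbedNet` explicitly, so `self.ntypes` must be assigned before the type-embedding net is built. A hedged sketch of the resulting call shape; the import path is an assumption and the options dict is hypothetical:

```python
from deepmd.tf.utils.type_embed import TypeEmbedNet  # assumed import path

ntypes = 2
type_embedding_options = {"neuron": [8], "resnet_dt": False}  # hypothetical user options
typeebd = TypeEmbedNet(
    ntypes=ntypes,  # now explicit, instead of being inferred inside the net
    **type_embedding_options,
    padding=True,
)
```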
4 changes: 3 additions & 1 deletion deepmd/tf/model/multi.py
@@ -146,11 +146,13 @@ def __init__(
dim_descrpt=self.descrpt.get_dim_out(),
)

+        self.ntypes = self.descrpt.get_ntypes()
# type embedding
if type_embedding is not None and isinstance(type_embedding, TypeEmbedNet):
self.typeebd = type_embedding
elif type_embedding is not None:
self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
**type_embedding,
padding=self.descrpt.explicit_ntypes,
)
@@ -159,6 +161,7 @@
default_args_dict = {i.name: i.default for i in default_args}
default_args_dict["activation_function"] = None
self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
**default_args_dict,
padding=True,
)
@@ -167,7 +170,6 @@

# descriptor
self.rcut = self.descrpt.get_rcut()
-        self.ntypes = self.descrpt.get_ntypes()
# fitting
self.fitting_dict = fitting_dict
self.numb_fparam_dict = {
3 changes: 2 additions & 1 deletion deepmd/tf/model/pairwise_dprc.py
@@ -77,11 +77,13 @@ def __init__(
compress=compress,
**kwargs,
)
+        self.ntypes = len(type_map)
# type embedding
if isinstance(type_embedding, TypeEmbedNet):
self.typeebd = type_embedding
else:
self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
**type_embedding,
# must use se_atten, so it must be True
padding=True,
@@ -100,7 +102,6 @@
compress=compress,
)
add_data_requirement("aparam", 1, atomic=True, must=True, high_prec=False)
-        self.ntypes = len(type_map)
self.rcut = max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())

def build(