Skip to content

Commit

Permalink
Feat: add polar consistency test (#3327)
Browse files Browse the repository at this point in the history
This PR is to add cross framework consistency test on PolarFittingNet.

Note: `shift_diag` not yet implemented in PT.

---------

Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Feb 23, 2024
1 parent 260ef21 commit 543276a
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 19 deletions.
15 changes: 12 additions & 3 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __init__(
fit_diag: bool = True,
scale: Optional[List[float]] = None,
shift_diag: bool = True,
# not used
seed: Optional[int] = None,
):
# seed, uniform_seed are not included
if tot_ener_zero:
Expand All @@ -119,9 +121,16 @@ def __init__(
if self.scale is None:
self.scale = [1.0 for _ in range(ntypes)]
else:
assert (
isinstance(self.scale, list) and len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
if isinstance(self.scale, list):
assert (
len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
elif isinstance(self.scale, float):
self.scale = [self.scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(
ntypes, 1
)
Expand Down
16 changes: 12 additions & 4 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
List,
Optional,
Union,
)

import torch
Expand Down Expand Up @@ -85,7 +86,7 @@ def __init__(
seed: Optional[int] = None,
exclude_types: List[int] = [],
fit_diag: bool = True,
scale: Optional[List[float]] = None,
scale: Optional[Union[List[float], float]] = None,
shift_diag: bool = True,
**kwargs,
):
Expand All @@ -95,9 +96,16 @@ def __init__(
if self.scale is None:
self.scale = [1.0 for _ in range(ntypes)]
else:
assert (
isinstance(self.scale, list) and len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
if isinstance(self.scale, list):
assert (
len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
elif isinstance(self.scale, float):
self.scale = [self.scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
self.scale = torch.tensor(
self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
).view(ntypes, 1)
Expand Down
97 changes: 87 additions & 10 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ class PolarFittingSeA(Fitting):
Parameters
----------
descrpt : tf.Tensor
The descrptor
ntypes
The ntypes of the descrptor :math:`\mathcal{D}`
dim_descrpt
The dimension of the descrptor :math:`\mathcal{D}`
embedding_width
The rotation matrix dimension of the descrptor :math:`\mathcal{D}`
neuron : List[int]
Number of neurons in each hidden layer of the fitting net
resnet_dt : bool
Expand All @@ -69,7 +73,9 @@ class PolarFittingSeA(Fitting):

def __init__(
self,
descrpt: tf.Tensor,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
sel_type: Optional[List[int]] = None,
Expand All @@ -84,8 +90,8 @@ def __init__(
**kwargs,
) -> None:
"""Constructor."""
self.ntypes = descrpt.get_ntypes()
self.dim_descrpt = descrpt.get_dim_out()
self.ntypes = ntypes
self.dim_descrpt = dim_descrpt
self.n_neuron = neuron
self.resnet_dt = resnet_dt
self.sel_type = sel_type
Expand All @@ -96,6 +102,7 @@ def __init__(
# self.diag_shift = diag_shift
self.shift_diag = shift_diag
self.scale = scale
self.activation_function_name = activation_function
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
if self.sel_type is None:
Expand All @@ -104,7 +111,19 @@ def __init__(
[ii in self.sel_type for ii in range(self.ntypes)], dtype=bool
)
if self.scale is None:
self.scale = [1.0 for ii in range(self.ntypes)]
self.scale = np.array([1.0 for ii in range(self.ntypes)])
else:
if isinstance(self.scale, list):
assert (
len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
elif isinstance(self.scale, float):
self.scale = [self.scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
self.scale = np.array(self.scale)
# if self.diag_shift is None:
# self.diag_shift = [0.0 for ii in range(self.ntypes)]
if not isinstance(self.sel_type, list):
Expand All @@ -115,10 +134,7 @@ def __init__(
) # self.ntypes x 1, store the average diagonal value
# if type(self.diag_shift) is not list:
# self.diag_shift = [self.diag_shift]
if not isinstance(self.scale, list):
self.scale = [self.scale for ii in range(self.ntypes)]
self.scale = np.array(self.scale)
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
self.dim_rot_mat_1 = embedding_width
self.dim_rot_mat = self.dim_rot_mat_1 * 3
self.useBN = False
self.fitting_net_variables = None
Expand Down Expand Up @@ -509,6 +525,67 @@ def get_loss(self, loss: dict, lr) -> Loss:
label_name="polarizability",
)

def serialize(self, suffix: str) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
data = {
"var_name": "polar",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"embedding_width": self.dim_rot_mat_1,
# very bad design: type embedding is not passed to the class
# TODO: refactor the class
"mixed_types": False,
"dim_out": 3,
"neuron": self.n_neuron,
"resnet_dt": self.resnet_dt,
"activation_function": self.activation_function_name,
"precision": self.fitting_precision.name,
"exclude_types": [],
"fit_diag": self.fit_diag,
"scale": list(self.scale),
"shift_diag": self.shift_diag,
"nets": self.serialize_network(
ntypes=self.ntypes,
# TODO: consider type embeddings
ndim=1,
in_dim=self.dim_descrpt,
out_dim=self.dim_rot_mat_1,
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
variables=self.fitting_net_variables,
suffix=suffix,
),
}
return data

@classmethod
def deserialize(cls, data: dict, suffix: str):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
fitting = cls(**data)
fitting.fitting_net_variables = cls.deserialize_network(
data["nets"],
suffix=suffix,
)
return fitting


class GlobalPolarFittingSeA:
r"""Fit the system polarizability with descriptor se_a.
Expand Down
Loading

0 comments on commit 543276a

Please sign in to comment.