Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt/dp): support case embedding and sharable fitting #4417

Merged
merged 16 commits into from
Nov 28, 2024
7 changes: 7 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def get_sel(self) -> list[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def set_caseid(self, case_idx):
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.fitting.set_caseid(case_idx)

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def get_model_rcuts(self) -> list[float]:
def get_sel(self) -> list[int]:
return [max([model.get_nsel() for model in self.models])]

def set_caseid(self, case_idx):
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
for model in self.models:
model.set_caseid(case_idx)

def get_model_nsels(self) -> list[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def get_sel(self) -> list[int]:
"""Returns the number of selected atoms for each type."""
pass

@abstractmethod
def set_caseid(self, case_idx) -> None:
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
pass

def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return sum(self.get_sel())
Expand Down
9 changes: 9 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ def get_type_map(self) -> list[str]:
def get_sel(self) -> list[int]:
return [self.sel]

def set_caseid(self, case_idx):
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
raise NotImplementedError(
"Case identification not supported for PairTabAtomicModel!"
)

def get_nsel(self) -> int:
return self.sel

Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_caseid=numb_caseid,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand Down Expand Up @@ -159,7 +161,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
var_name = data.pop("var_name", None)
assert var_name == "dipole"
return super().deserialize(data)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
bias_dos: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
trainable: Union[bool, list[bool]] = True,
Expand All @@ -60,6 +61,7 @@ def __init__(
bias_atom=bias_dos,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_caseid=numb_caseid,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
Expand All @@ -73,7 +75,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
data["numb_dos"] = data.pop("dim_out")
data.pop("tot_ener_zero", None)
data.pop("var_name", None)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand All @@ -55,6 +56,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_caseid=numb_caseid,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand All @@ -73,7 +75,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
35 changes: 34 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
bias_atom_e: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
Expand All @@ -127,6 +128,7 @@ def __init__(
self.resnet_dt = resnet_dt
self.numb_fparam = numb_fparam
self.numb_aparam = numb_aparam
self.numb_caseid = numb_caseid
self.rcond = rcond
self.tot_ener_zero = tot_ener_zero
self.trainable = trainable
Expand Down Expand Up @@ -171,11 +173,16 @@ def __init__(
self.aparam_inv_std = np.ones(self.numb_aparam, dtype=self.prec)
else:
self.aparam_avg, self.aparam_inv_std = None, None
if self.numb_caseid > 0:
self.caseid = np.zeros(self.numb_caseid, dtype=self.prec)
else:
self.caseid = None
# init networks
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
+ self.numb_caseid
)
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
Expand Down Expand Up @@ -222,6 +229,13 @@ def get_type_map(self) -> list[str]:
"""Get the name to each type of atoms."""
return self.type_map

def set_caseid(self, case_idx):
"""
Set the case identification of this fitting net by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.caseid = np.eye(self.numb_caseid, dtype=self.prec)[case_idx]

def change_type_map(
self, type_map: list[str], model_with_new_type_stat=None
) -> None:
Expand Down Expand Up @@ -255,6 +269,8 @@ def __setitem__(self, key, value) -> None:
self.aparam_avg = value
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
elif key in ["caseid"]:
self.caseid = value
elif key in ["scale"]:
self.scale = value
else:
Expand All @@ -271,6 +287,8 @@ def __getitem__(self, key):
return self.aparam_avg
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
elif key in ["caseid"]:
return self.caseid
elif key in ["scale"]:
return self.scale
else:
Expand All @@ -287,14 +305,15 @@ def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
"@class": "Fitting",
"@version": 2,
"@version": 3,
"var_name": self.var_name,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"numb_fparam": self.numb_fparam,
"numb_aparam": self.numb_aparam,
"numb_caseid": self.numb_caseid,
"rcond": self.rcond,
"activation_function": self.activation_function,
"precision": self.precision,
Expand All @@ -303,6 +322,7 @@ def serialize(self) -> dict:
"nets": self.nets.serialize(),
"@variables": {
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"caseid": to_numpy_array(self.caseid),
"fparam_avg": to_numpy_array(self.fparam_avg),
"fparam_inv_std": to_numpy_array(self.fparam_inv_std),
"aparam_avg": to_numpy_array(self.aparam_avg),
Expand Down Expand Up @@ -423,6 +443,19 @@ def _call_common(
axis=-1,
)

if self.numb_caseid > 0:
assert self.caseid is not None
caseid = xp.tile(xp.reshape(self.caseid, [1, 1, -1]), [nf, nloc, 1])
xx = xp.concat(
[xx, caseid],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = xp.concat(
[xx_zeros, caseid],
axis=-1,
)

# calculate the prediction
if not self.mixed_types:
outs = xp.zeros(
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
bias_atom: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_caseid=numb_caseid,
rcond=rcond,
bias_atom_e=bias_atom,
tot_ener_zero=tot_ener_zero,
Expand Down Expand Up @@ -183,7 +185,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
return super().deserialize(data)

def _net_out_dim(self):
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_caseid=numb_caseid,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand Down Expand Up @@ -187,7 +189,7 @@ def __getitem__(self, key):
def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["@version"] = 3
data["@version"] = 4
data["embedding_width"] = self.embedding_width
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
Expand All @@ -198,7 +200,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 3, 1)
check_version_compatibility(data.pop("@version", 1), 4, 1)
var_name = data.pop("var_name", None)
assert var_name == "polar"
return super().deserialize(data)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_caseid: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = True,
Expand All @@ -99,6 +100,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_caseid=numb_caseid,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
Expand All @@ -111,7 +113,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "PropertyFittingNet":
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
check_version_compatibility(data.pop("@version"), 3, 1)
data.pop("dim_out")
data.pop("var_name")
data.pop("tot_ener_zero")
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def serialize(self) -> dict:
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))

def set_caseid(self, case_idx):
self.atomic_model.set_caseid(case_idx)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.atomic_model.get_dim_fparam()
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def get_sel(self) -> list[int]:
"""Get the neighbor selection."""
return self.sel

def set_caseid(self, case_idx):
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.fitting_net.set_caseid(case_idx)

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ def get_model_rcuts(self) -> list[float]:
def get_sel(self) -> list[int]:
return [max([model.get_nsel() for model in self.models])]

def set_caseid(self, case_idx):
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
for model in self.models:
model.set_caseid(case_idx)

def get_model_nsels(self) -> list[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def get_type_map(self) -> list[str]:
def get_sel(self) -> list[int]:
return [self.sel]

def set_caseid(self, case_idx):
"""
Set the case identification of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
raise NotImplementedError(
"Case identification not supported for PairTabAtomicModel!"
)

def get_nsel(self) -> int:
return self.sel

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ def serialize(self) -> dict:
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))

def set_caseid(self, case_idx):
self.atomic_model.set_caseid(case_idx)

@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down
Loading
Loading