Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement hessian autodiff calculation #3262

Merged
merged 10 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
OutputVariableDef,
fitting_check_output,
get_deriv_name,
get_hessian_name,
get_reduce_name,
model_check_output,
)
Expand All @@ -31,4 +32,5 @@
"fitting_check_output",
"get_reduce_name",
"get_deriv_name",
"get_hessian_name",
]
4 changes: 4 additions & 0 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@
return name + "_derv_r", name + "_derv_c"


def get_hessian_name(name: str) -> str:
return name + "_derv_r_derv_r"

Check warning on line 323 in deepmd/dpmodel/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/output_def.py#L323

Added line #L323 was not covered by tests


def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int:
"""Apply a operation to the category of a variable definition.

Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1)
weight = weight * np.expand_dims(mask, -1)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight

Check warning on line 57 in deepmd/dpmodel/utils/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/env_mat.py#L56-L57

Added lines #L56 - L57 were not covered by tests
return env_mat_se_a, diff * np.expand_dims(mask, -1), weight


Expand Down
7 changes: 5 additions & 2 deletions deepmd/pt/model/descriptor/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
"""Make smooth environment matrix."""
bsz, natoms, nnei = nlist.shape
coord = coord.view(bsz, -1, 3)
nall = coord.shape[1]

Check warning on line 13 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L13

Added line #L13 was not covered by tests
mask = nlist >= 0
nlist = nlist * mask
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)

Check warning on line 16 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L16

Added line #L16 was not covered by tests
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
Expand All @@ -23,7 +25,8 @@
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight * mask.unsqueeze(-1)
weight = weight * mask.unsqueeze(-1)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight

Check warning on line 29 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L28-L29

Added lines #L28 - L29 were not covered by tests
return env_mat_se_a, diff * mask.unsqueeze(-1), weight


Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
EnergyModel,
ZBLModel,
)
from .make_hessian_model import (

Check warning on line 21 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L21

Added line #L21 was not covered by tests
make_hessian_model,
)
from .model import (
BaseModel,
)
Expand Down Expand Up @@ -84,4 +87,5 @@
"BaseModel",
"EnergyModel",
"get_model",
"make_hessian_model",
]
21 changes: 19 additions & 2 deletions deepmd/pt/model/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def forward(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
Expand All @@ -63,13 +66,18 @@ def forward_lower(
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
Expand Down Expand Up @@ -109,7 +117,12 @@ def forward(
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord, atype, box, do_atomic_virial=do_atomic_virial
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
model_predict = {}
Expand All @@ -135,13 +148,17 @@ def forward_lower(
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
Expand Down
216 changes: 216 additions & 0 deletions deepmd/pt/model/model/make_hessian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import math
from typing import (

Check warning on line 4 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L2-L4

Added lines #L2 - L4 were not covered by tests
Dict,
List,
Optional,
Union,
)

import torch

Check warning on line 11 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L11

Added line #L11 was not covered by tests

from deepmd.dpmodel import (

Check warning on line 13 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L13

Added line #L13 was not covered by tests
get_hessian_name,
)


def make_hessian_model(T_Model):

Check warning on line 18 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L18

Added line #L18 was not covered by tests
"""Make a model that can compute Hessian.

LIMITATION: this model is not jitable due to the restrictions of torch jit script.

LIMITATION: only the hessian of `forward_common` is available.

Parameters
----------
T_Model
The model. Should provide the `forward_common` and `fitting_output_def` methods

Returns
-------
The model computes hessian.

"""

class CM(T_Model):
def __init__(

Check warning on line 37 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L36-L37

Added lines #L36 - L37 were not covered by tests
self,
*args,
**kwargs,
):
super().__init__(

Check warning on line 42 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L42

Added line #L42 was not covered by tests
*args,
**kwargs,
)
self.hess_fitting_def = copy.deepcopy(super().fitting_output_def())

Check warning on line 46 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L46

Added line #L46 was not covered by tests

def requires_hessian(

Check warning on line 48 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L48

Added line #L48 was not covered by tests
self,
keys: Union[str, List[str]],
):
"""Set which output variable(s) requires hessian."""
if isinstance(keys, str):
keys = [keys]
for kk in self.hess_fitting_def.keys():
if kk in keys:
self.hess_fitting_def[kk].r_hessian = True

Check warning on line 57 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L53-L57

Added lines #L53 - L57 were not covered by tests

def fitting_output_def(self):

Check warning on line 59 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L59

Added line #L59 was not covered by tests
"""Get the fitting output def."""
return self.hess_fitting_def

Check warning on line 61 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L61

Added line #L61 was not covered by tests

def forward_common(

Check warning on line 63 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L63

Added line #L63 was not covered by tests
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
"""Return model prediction.

Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.

Returns
-------
ret_dict
The result dict of type Dict[str,torch.Tensor].
The keys are defined by the `ModelOutputDef`.

"""
ret = super().forward_common(

Check warning on line 97 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L97

Added line #L97 was not covered by tests
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
vdef = self.fitting_output_def()
hess_yes = [vdef[kk].r_hessian for kk in vdef.keys()]
if any(hess_yes):
hess = self._cal_hessian_all(

Check warning on line 108 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L105-L108

Added lines #L105 - L108 were not covered by tests
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
)
ret.update(hess)
return ret

Check warning on line 116 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L115-L116

Added lines #L115 - L116 were not covered by tests

def _cal_hessian_all(

Check warning on line 118 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L118

Added line #L118 was not covered by tests
self,
coord: torch.Tensor,
atype: torch.Tensor,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
nf, nloc = atype.shape
coord = coord.view([nf, (nloc * 3)])
box = box.view([nf, 9]) if box is not None else None
fparam = fparam.view([nf, -1]) if fparam is not None else None
aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None
fdef = self.fitting_output_def()

Check warning on line 131 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L126-L131

Added lines #L126 - L131 were not covered by tests
# keys of values that require hessian
hess_keys: List[str] = []
for kk in fdef.keys():
if fdef[kk].r_hessian:
hess_keys.append(kk)

Check warning on line 136 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L133-L136

Added lines #L133 - L136 were not covered by tests
# result dict init by empty lists
res = {get_hessian_name(kk): [] for kk in hess_keys}

Check warning on line 138 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L138

Added line #L138 was not covered by tests
# loop over variable
for kk in hess_keys:
vdef = fdef[kk]
vshape = vdef.shape
vsize = math.prod(vdef.shape)

Check warning on line 143 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L140-L143

Added lines #L140 - L143 were not covered by tests
# loop over frames
for ii in range(nf):
icoord = coord[ii]
iatype = atype[ii]
ibox = box[ii] if box is not None else None
ifparam = fparam[ii] if fparam is not None else None
iaparam = aparam[ii] if aparam is not None else None

Check warning on line 150 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L145-L150

Added lines #L145 - L150 were not covered by tests
# loop over all components
for idx in range(vsize):
hess = self._cal_hessian_one_component(

Check warning on line 153 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L152-L153

Added lines #L152 - L153 were not covered by tests
idx, icoord, iatype, ibox, ifparam, iaparam
)
res[get_hessian_name(kk)].append(hess)
res[get_hessian_name(kk)] = torch.stack(res[get_hessian_name(kk)]).view(

Check warning on line 157 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L156-L157

Added lines #L156 - L157 were not covered by tests
(nf, *vshape, nloc * 3, nloc * 3)
)
return res

Check warning on line 160 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L160

Added line #L160 was not covered by tests

def _cal_hessian_one_component(

Check warning on line 162 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L162

Added line #L162 was not covered by tests
self,
ci,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# coord, # (nloc x 3)
# atype, # nloc
# box: Optional[torch.Tensor] = None, # 9
# fparam: Optional[torch.Tensor] = None, # nfp
# aparam: Optional[torch.Tensor] = None, # (nloc x nap)
wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam)

Check warning on line 176 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L176

Added line #L176 was not covered by tests

hess = torch.autograd.functional.hessian(

Check warning on line 178 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L178

Added line #L178 was not covered by tests
wc,
coord,
create_graph=False,
)
return hess

Check warning on line 183 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L183

Added line #L183 was not covered by tests

class wrapper_class_forward_energy:
def __init__(

Check warning on line 186 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L185-L186

Added lines #L185 - L186 were not covered by tests
self,
obj: CM,
ci: int,
atype: torch.Tensor,
box: Optional[torch.Tensor],
fparam: Optional[torch.Tensor],
aparam: Optional[torch.Tensor],
):
self.atype, self.box, self.fparam, self.aparam = atype, box, fparam, aparam
self.ci = ci
self.obj = obj

Check warning on line 197 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L195-L197

Added lines #L195 - L197 were not covered by tests

def __call__(

Check warning on line 199 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L199

Added line #L199 was not covered by tests
self,
xx,
):
ci = self.ci
atype, box, fparam, aparam = self.atype, self.box, self.fparam, self.aparam
res = super(CM, self.obj).forward_common(

Check warning on line 205 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L203-L205

Added lines #L203 - L205 were not covered by tests
xx.unsqueeze(0),
atype.unsqueeze(0),
box.unsqueeze(0) if box is not None else None,
fparam.unsqueeze(0) if fparam is not None else None,
aparam.unsqueeze(0) if aparam is not None else None,
do_atomic_virial=False,
)
er = res["energy_redu"][0].view([-1])[ci]
return er

Check warning on line 214 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L213-L214

Added lines #L213 - L214 were not covered by tests

return CM

Check warning on line 216 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L216

Added line #L216 was not covered by tests
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.

Expand Down Expand Up @@ -155,6 +159,10 @@
neighbor list. nf x nloc x nsel.
mapping
mapps the extended indices to local indices. nf x nall.
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda

Check warning on line 165 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L164-L165

Added lines #L164 - L165 were not covered by tests
do_atomic_virial
whether calculate atomic virial.

Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@
assert aviri is not None
aviri = aviri.unsqueeze(-2)
split_avir.append(aviri)
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
ff = torch.concat(split_ff, dim=-2)
# nf x nall x v_dim x 3, nf x nall x v_dim x 9
out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape
ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3]) # noqa: RUF005

Check warning on line 133 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L132-L133

Added lines #L132 - L133 were not covered by tests
if do_virial:
avir = torch.concat(split_avir, dim=-2)
avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9]) # noqa: RUF005

Check warning on line 135 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L135

Added line #L135 was not covered by tests
else:
avir = None
return ff, avir
Expand Down
Loading
Loading