Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add se_r descriptor #3338

Merged
merged 23 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
af51721
feat: add se_r descriptor
anyangml Feb 26, 2024
c0af6fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
8a8107c
fix: UTs, removed old impl
anyangml Feb 26, 2024
fb6340b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
2eb4041
fix: pre-commit
anyangml Feb 26, 2024
6c5224f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
b771db4
fix: update se_r output
anyangml Feb 26, 2024
8a1a86c
chore: refactor
anyangml Feb 26, 2024
50cdfe0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
84d61da
feat: add numpy impl
anyangml Feb 26, 2024
08a6988
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
1f0fd99
fix: UTs
anyangml Feb 27, 2024
c07e02c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
3485608
Merge branch 'devel' into devel
anyangml Feb 27, 2024
e5b074e
gix: match serialization
anyangml Feb 27, 2024
8265242
Merge branch 'devel' into devel
anyangml Feb 27, 2024
3fe3ed4
chore: refactor device
anyangml Feb 27, 2024
9d23e96
chore: refactor device
anyangml Feb 27, 2024
c073241
fix: UTs
anyangml Feb 27, 2024
ce85c0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
72d7a37
fix: dtype
anyangml Feb 27, 2024
54375bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
184e6a0
Merge branch 'devel' into devel
anyangml Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from .env_mat import (
prod_env_mat_se_a,
prod_env_mat_se_r,
)
from .gaussian_lcc import (
DescrptGaussianLcc,
Expand All @@ -27,6 +28,9 @@
DescrptBlockSeA,
DescrptSeA,
)
from .se_r import (
DescrptSeR,
)

__all__ = [
"Descriptor",
Expand All @@ -35,9 +39,11 @@
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"prod_env_mat_se_a",
"prod_env_mat_se_r",
"DescrptGaussianLcc",
"DescrptBlockHybrid",
"DescrptBlockRepformers",
Expand Down
53 changes: 53 additions & 0 deletions deepmd/pt/model/descriptor/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@
return env_mat_se_a, diff * mask.unsqueeze(-1), weight


def _make_env_mat_se_r(nlist, coord, rcut: float, ruct_smth: float):
"""Make smooth environment matrix."""
bsz, natoms, nnei = nlist.shape
coord = coord.view(bsz, -1, 3)
nall = coord.shape[1]
mask = nlist >= 0
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
coord_r = coord_r.view(bsz, natoms, nnei, 3)
diff = coord_r - coord_l
length = torch.linalg.norm(diff, dim=-1, keepdim=True)
# for index 0 nloc atom
length = length + ~mask.unsqueeze(-1)
t0 = 1 / length
weight = compute_smooth_weight(length, ruct_smth, rcut)
weight = weight * mask.unsqueeze(-1)
env_mat_se_r = t0 * weight
return env_mat_se_r, diff * mask.unsqueeze(-1), weight


def prod_env_mat_se_a(
extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float
):
Expand Down Expand Up @@ -58,3 +81,33 @@
t_std = stddev[atype] # [n_atom, dim, 4]
env_mat_se_a = (_env_mat_se_a - t_avg) / t_std
return env_mat_se_a, diff, switch


def prod_env_mat_se_r(
extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float
):
"""Generate smooth environment matrix from atom coordinates and other context.

Args:
- extended_coord: Copied atom coordinates with shape [nframes, nall*3].
- atype: Atom types with shape [nframes, nloc].
- natoms: Batched atom statisics with shape [len(sec)+2].
- box: Batched simulation box with shape [nframes, 9].
- mean: Average value of descriptor per element type with shape [len(sec), nnei, 1].
- stddev: Standard deviation of descriptor per element type with shape [len(sec), nnei, 1].
- deriv_stddev: StdDev of descriptor derivative per element type with shape [len(sec), nnei, 1, 3].
- rcut: Cut-off radius.
- rcut_smth: Smooth hyper-parameter for pair force & energy.

Returns
-------
- env_mat_se_r: Shape is [nframes, natoms[1]*nnei*1].
"""
nframes = extended_coord.shape[0]
_env_mat_se_r, diff, switch = _make_env_mat_se_r(
nlist, extended_coord, rcut, rcut_smth
) # shape [n_atom, dim, 1]
t_avg = mean[atype] # [n_atom, dim, 1]
t_std = stddev[atype] # [n_atom, dim, 1]
env_mat_se_r = (_env_mat_se_r - t_avg) / t_std
return env_mat_se_r, diff, switch
Loading
Loading