
Feat: add zbl training #3398

Merged · 71 commits · Mar 6, 2024
Changes from 9 commits

Commits (71)
063f8e7
feat: add zbl training
anyangml Mar 3, 2024
8f06ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
2f7fa77
fix: add atom bias
anyangml Mar 3, 2024
672563c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
312973a
Merge branch 'devel' into devel
anyangml Mar 3, 2024
cf66829
Merge branch 'devel' into devel
anyangml Mar 3, 2024
52ab95f
chore: refactor
anyangml Mar 3, 2024
993efe9
fix: add pairtab stat
anyangml Mar 3, 2024
897f9f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
e8320b6
Merge branch 'devel' into devel
anyangml Mar 3, 2024
701cb55
fix: add UTs
anyangml Mar 3, 2024
e27a816
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
dc30bbd
fix: add UT input
anyangml Mar 3, 2024
a232cf3
fix: UTs
anyangml Mar 3, 2024
d9856e7
Merge branch 'devel' into devel
anyangml Mar 3, 2024
004b63e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
ca99701
fix: UTs
anyangml Mar 3, 2024
162fc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
9c25175
fix: UTs
anyangml Mar 3, 2024
8fc3a70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
84fb816
chore: merge conflict
anyangml Mar 3, 2024
55e2b7f
fix: update numpy shape
anyangml Mar 3, 2024
0b9f7ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
6524694
fix: UTs
anyangml Mar 3, 2024
e3d9a7b
feat: add UTs
anyangml Mar 3, 2024
e648ab4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
7143aa9
Merge branch 'devel' into devel
anyangml Mar 4, 2024
6ed8fde
fix: UTs
anyangml Mar 4, 2024
aadddcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f36988d
fix: UTs
anyangml Mar 4, 2024
9c9cbbe
feat: update UTs
anyangml Mar 4, 2024
d2adebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
7071608
Merge branch 'devel' into devel
anyangml Mar 4, 2024
00c877c
fix: UTs
anyangml Mar 4, 2024
eb36de2
Merge branch 'devel' into devel
anyangml Mar 4, 2024
5de7214
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
8b35fa4
rix: revert abstract method
anyangml Mar 4, 2024
bbc7ad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
c384b3b
fix: UTs
anyangml Mar 4, 2024
1d5fad0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
dc407e3
chore: refactor
anyangml Mar 4, 2024
a9f65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
18a4897
fix: precommit
anyangml Mar 4, 2024
94bea6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
09f9352
fix: precommit
anyangml Mar 4, 2024
a63089d
fix: UTs
anyangml Mar 4, 2024
bda547d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
e6be71b
fix: UTs
anyangml Mar 4, 2024
3482ef2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
a2afe7c
Merge branch 'devel' into devel
anyangml Mar 4, 2024
f067c4c
Merge branch 'devel' into devel
anyangml Mar 5, 2024
b0e4749
Merge branch 'devel' into devel
anyangml Mar 5, 2024
f8e340a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
3610b3d
feat: add atype remap
anyangml Mar 5, 2024
caf5f78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
7d4e49c
fix: add UTs
anyangml Mar 5, 2024
a30bc35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
d436444
fix: UTs
anyangml Mar 5, 2024
af1349c
Merge branch 'devel' into devel
anyangml Mar 5, 2024
ba643b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
b60f038
fix: update numpy
anyangml Mar 5, 2024
e5a905b
chore:skip test
anyangml Mar 5, 2024
25f1ff8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
85baa59
chore: rename class
anyangml Mar 5, 2024
a29541c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
1b04a43
Merge branch 'devel' into devel
anyangml Mar 6, 2024
18ec6a5
fix: add TODO
anyangml Mar 6, 2024
d62bacb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
c012b3f
chore: refactor remap
anyangml Mar 6, 2024
a0d7caf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
edd9a10
fix: UTs
anyangml Mar 6, 2024
45 changes: 37 additions & 8 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -25,6 +25,9 @@
get_multiple_nlist_key,
nlist_distinguish_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -184,14 +187,17 @@

weights = self._compute_weight(extended_coord, extended_atype, nlists_)

if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")
else:
fit_ret = {
"energy": torch.sum(
torch.stack(ener_list) * torch.stack(weights), dim=0
),
} # (nframes, nloc, 1)
atype = extended_atype[:, :nloc]
for idx, m in enumerate(self.models):
if isinstance(m, DPAtomicModel) and m.fitting_net is not None:
bias_atom_e = m.fitting_net.bias_atom_e
elif isinstance(m, PairTabAtomicModel):
bias_atom_e = m.bias_atom_e
ener_list[idx] += bias_atom_e[atype]

fit_ret = {
"energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0),
} # (nframes, nloc, 1)
return fit_ret
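
Viewed outside the diff, the hunk above replaces the old `NotImplementedError` branch: each sub-model's per-type bias table is gathered by atom type, added to that model's atomic energies, and the results are mixed with the interpolation weights. A minimal numpy sketch of that logic (shapes and values are illustrative, not taken from the PR):

```python
import numpy as np

# Illustrative shapes: 2 sub-models, nframes=1, nloc=3, ntypes=2.
atype = np.array([[0, 1, 0]])                        # (nframes, nloc)
ener_list = [np.ones((1, 3, 1)), np.zeros((1, 3, 1))]
weights = [np.full((1, 3, 1), 0.75), np.full((1, 3, 1), 0.25)]
# One (ntypes, 1) bias table per sub-model, like bias_atom_e above.
biases = [np.array([[0.1], [0.2]]), np.array([[0.3], [0.4]])]

# Gather each atom's bias by its type and add it to that model's energies.
for idx, bias_atom_e in enumerate(biases):
    ener_list[idx] = ener_list[idx] + bias_atom_e[atype]   # (1, 3, 1)

# Weighted sum over models, mirroring torch.sum(torch.stack(...), dim=0).
energy = np.sum(np.stack(ener_list) * np.stack(weights), axis=0)
```

Indexing a `(ntypes, 1)` table with a `(nframes, nloc)` integer array yields `(nframes, nloc, 1)`, which is why the bias broadcasts cleanly onto the atomic energies.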

def fitting_output_def(self) -> FittingOutputDef:
@@ -307,6 +313,29 @@
# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.

Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
"""
self.dp_model.compute_or_load_stat(sampled_func, stat_file_path)
self.zbl_model.compute_output_stats(sampled_func, stat_file_path)

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
74 changes: 74 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Callable,
Dict,
List,
Optional,
@@ -13,9 +14,21 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
from deepmd.utils.pair_tab import (
PairTab,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -57,6 +70,7 @@
self.tab_file = tab_file
self.rcut = rcut
self.tab = self._set_pairtab(tab_file, rcut)

BaseAtomicModel.__init__(self, **kwargs)

# handle deserialization with no input file
@@ -70,6 +84,7 @@
else:
self.register_buffer("tab_info", None)
self.register_buffer("tab_data", None)
self.bias_atom_e = None

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class
@@ -154,6 +169,65 @@
tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data))
return tab_model

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_atom_e"
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
energy = [item["energy"] for item in sampled]
data_mixed_type = "real_natoms_vec" in sampled[0]
if data_mixed_type:
input_natoms = [item["real_natoms_vec"] for item in sampled]
else:
input_natoms = [item["natoms"] for item in sampled]
# shape: (nframes, ndim)
merged_energy = to_numpy_array(torch.cat(energy))
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])

bias_atom_e, _ = compute_stats_from_redu(
merged_energy,
merged_natoms,
assigned_bias=None,
rcond=None,
)
if stat_file_path is not None:
stat_file_path.save_numpy(bias_atom_e)
assert all(x is not None for x in [bias_atom_e])
ntypes = merged_natoms.shape[1]
self.bias_atom_e = torch.empty(
[ntypes, 1], dtype=torch.float64, device=env.DEVICE
)
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view([ntypes, 1])
)
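
`compute_stats_from_redu` amounts to a least-squares fit of per-type biases b such that E ≈ N · b frame by frame. A self-contained numpy equivalent, on synthetic data rather than the deepmd helper:

```python
import numpy as np

rng = np.random.default_rng(0)
nframes, ntypes = 50, 2
true_bias = np.array([[1.5], [-0.5]])                        # (ntypes, 1)
# Per-frame atom counts per type, like merged_natoms[:, 2:] above.
merged_natoms = rng.integers(1, 10, size=(nframes, ntypes)).astype(float)
merged_energy = merged_natoms @ true_bias                    # (nframes, 1)

# Solve merged_natoms @ bias ≈ merged_energy in the least-squares sense,
# the same reduction-statistics problem the bias fit above solves.
bias_atom_e, *_ = np.linalg.lstsq(merged_natoms, merged_energy, rcond=None)
```

With noise-free synthetic energies the fit recovers `true_bias` exactly; with real training data the residual absorbs everything the per-type bias cannot explain.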

def forward_atomic(
self,
extended_coord: torch.Tensor,
6 changes: 5 additions & 1 deletion deepmd/pt/train/training.py
@@ -26,6 +26,7 @@
)
from deepmd.pt.model.model import (
get_model,
get_zbl_model,
)
from deepmd.pt.optimizer import (
KFOptimizerWrapper,
@@ -243,7 +244,10 @@
def get_single_model(
_model_params,
):
model = get_model(deepcopy(_model_params)).to(DEVICE)
if "use_srtab" in _model_params:
model = get_zbl_model(deepcopy(_model_params)).to(DEVICE)
else:
model = get_model(deepcopy(_model_params)).to(DEVICE)
return model
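
The branch above keys model construction on the presence of `"use_srtab"` in the model parameters. A toy sketch of the dispatch, with the builders passed in as stand-ins for the real `get_model`/`get_zbl_model`:

```python
from copy import deepcopy

def get_single_model(model_params, get_model, get_zbl_model):
    # A short-range tabulated potential in the config selects the ZBL path.
    if "use_srtab" in model_params:
        return get_zbl_model(deepcopy(model_params))
    return get_model(deepcopy(model_params))
```

The `deepcopy` mirrors the diff: each builder receives its own copy of the parameters, so construction cannot mutate the shared training config.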

def get_lr(lr_params):