From e41b091e8c5be18b98c375ee8ddbbe55f5552f63 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 14 Feb 2024 00:20:48 +0800 Subject: [PATCH] breaking: pt: remove data preprocess from data stat (#3261) This PR: - Remove data preprocess from data stat. - Cleanup dependency of data preprocess in dataset and dataloader. Note that: - `DeepmdDataSystem` still has dependency for PyTorch, which leaves for @CaRoLZhangxy to clean up. - Denoise part in `DeepmdDataSystem` still needs further clean up, which leaves for @Chengqian-Zhang. --------- Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pt/model/descriptor/dpa1.py | 8 +- deepmd/pt/model/descriptor/dpa2.py | 11 +- deepmd/pt/model/descriptor/hybrid.py | 10 +- deepmd/pt/model/descriptor/repformers.py | 46 +-- deepmd/pt/model/descriptor/se_a.py | 46 ++- deepmd/pt/model/descriptor/se_atten.py | 42 ++- deepmd/pt/model/model/make_model.py | 27 +- deepmd/pt/model/task/ener.py | 2 +- deepmd/pt/utils/dataloader.py | 21 +- deepmd/pt/utils/dataset.py | 290 +----------------- deepmd/pt/utils/nlist.py | 32 ++ deepmd/pt/utils/stat.py | 30 +- source/tests/pt/model/test_descriptor.py | 52 +++- source/tests/pt/model/test_descriptor_dpa1.py | 48 ++- source/tests/pt/model/test_descriptor_dpa2.py | 81 ++--- source/tests/pt/model/test_dp_model.py | 12 +- source/tests/pt/model/test_embedding_net.py | 55 +++- source/tests/pt/model/test_model.py | 4 +- source/tests/pt/test_loss.py | 13 +- source/tests/pt/test_stat.py | 5 +- 20 files changed, 325 insertions(+), 510 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 6c1331ec1d..76cff174af 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging from typing import ( List, Optional, @@ -18,8 +17,6 @@ DescrptBlockSeAtten, ) -log = logging.getLogger(__name__) - @Descriptor.register("dpa1") @Descriptor.register("se_atten") @@ -112,7 +109,7 @@ def distinguish_types(self) -> bool: """Returns if the descriptor requires a neighbor list that distinguish different atomic types or not. """ - return False + return self.se_atten.distinguish_types() @property def dim_out(self): @@ -128,7 +125,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2) @classmethod @@ -141,6 +138,7 @@ def get_stat_name( """ descrpt_type = type_name assert descrpt_type in ["dpa1", "se_atten"] + assert all(x is not None for x in [rcut, rcut_smth, sel]) return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" @classmethod diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 05e7cec658..6cefaf6f38 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging from typing import ( List, Optional, @@ -27,8 +26,6 @@ DescrptBlockSeAtten, ) -log = logging.getLogger(__name__) - @Descriptor.register("dpa2") class DescrptDPA2(Descriptor): @@ -316,7 +313,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) for ii, descrpt in enumerate([self.repinit, self.repformers]): stat_dict_ii = { "sumr": sumr[ii], @@ -346,8 +343,8 @@ def get_stat_name( """ descrpt_type = type_name assert descrpt_type in ["dpa2"] - assert True not in [ - x is None + assert all( + x is not None for x in [ repinit_rcut, repinit_rcut_smth, @@ -356,7 +353,7 @@ def get_stat_name( repformer_rcut_smth, repformer_nsel, ] - ] + ) return ( f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}" f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz" diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index fb7e374ede..c5c08c760d 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -103,6 +103,14 @@ def get_dim_in(self) -> int: def get_dim_emb(self): return self.dim_emb + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return any( + descriptor.distinguish_types() for descriptor in self.descriptor_list + ) + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -170,7 +178,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) for ii, descrpt in enumerate(self.descriptor_list): stat_dict_ii = { "sumr": sumr[ii], diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 0a302b6f92..26467124b8 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -21,7 +21,7 @@ env, ) from deepmd.pt.utils.nlist import ( - build_neighbor_list, + extend_input_and_build_neighbor_list, ) from deepmd.pt.utils.utils import ( get_activation_fn, @@ -178,6 +178,12 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension g2.""" return self.g2_dim + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return False + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -272,28 +278,29 @@ def compute_input_stats(self, merged): suma2 = [] mixed_type = "real_natoms_vec" in merged[0] for system in merged: - index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3) - extended_coord = torch.gather(system["coord"], dim=1, index=index) - extended_coord = extended_coord - system["shift"] - index = system["mapping"] - extended_atype = torch.gather(system["atype"], dim=1, index=index) - nloc = system["atype"].shape[-1] - ####################################################### - # dirty hack here! the interface of dataload should be - # redesigned to support descriptors like dpa2 - ####################################################### - nlist = build_neighbor_list( + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( extended_coord, extended_atype, - nloc, - self.rcut, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), self.get_sel(), - distinguish_types=False, + distinguish_types=self.distinguish_types(), + box=box, ) env_mat, _, _ = prod_env_mat_se_a( extended_coord, nlist, - system["atype"], + atype, self.mean, self.stddev, self.rcut, @@ -301,15 +308,16 @@ def compute_input_stats(self, merged): ) if not mixed_type: sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), ndescrpt, system["natoms"] + env_mat.detach().cpu().numpy(), ndescrpt, natoms ) else: + real_natoms_vec = system["real_natoms_vec"] sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( env_mat.detach().cpu().numpy(), ndescrpt, - system["real_natoms_vec"], + real_natoms_vec, mixed_type=mixed_type, - real_atype=system["atype"].detach().cpu().numpy(), + real_atype=atype.detach().cpu().numpy(), ) sumr.append(sysr) suma.append(sysa) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 82e7e5185a..700bf6d59b 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging from typing import ( ClassVar, List, @@ -37,8 +36,9 @@ from deepmd.pt.model.network.network import ( TypeFilter, ) - -log = logging.getLogger(__name__) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) @Descriptor.register("se_e2_a") @@ -100,7 +100,7 @@ def distinguish_types(self): """Returns if the descriptor requires a neighbor list that distinguish different atomic types or not. """ - return True + return self.sea.distinguish_types() @property def dim_out(self): @@ -114,7 +114,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2) @classmethod @@ -127,7 +127,7 @@ def get_stat_name( """ descrpt_type = type_name assert descrpt_type in ["se_e2_a"] - assert True not in [x is None for x in [rcut, rcut_smth, sel]] + assert all(x is not None for x in [rcut, rcut_smth, sel]) return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" @classmethod @@ -347,6 +347,12 @@ def get_dim_in(self) -> int: """Returns the input dimension.""" return self.dim_in + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return True + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -381,20 +387,36 @@ def compute_input_stats(self, merged): sumr2 = [] suma2 = [] for system in merged: - index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3) - extended_coord = torch.gather(system["coord"], dim=1, index=index) - extended_coord = extended_coord - system["shift"] + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.get_sel(), + distinguish_types=self.distinguish_types(), + box=box, + ) env_mat, _, _ = prod_env_mat_se_a( extended_coord, - system["nlist"], - system["atype"], + nlist, + atype, self.mean, self.stddev, self.rcut, self.rcut_smth, ) sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"] + env_mat.detach().cpu().numpy(), self.ndescrpt, natoms ) sumr.append(sysr) suma.append(sysa) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 3469d43e40..d4dc0cd054 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -21,6 +21,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) @DescriptorBlock.register("se_atten") @@ -161,6 +164,12 @@ def get_dim_emb(self) -> int: """Returns the output dimension of embedding.""" return self.filter_neuron[-1] + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return False + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -185,13 +194,29 @@ def compute_input_stats(self, merged): suma2 = [] mixed_type = "real_natoms_vec" in merged[0] for system in merged: - index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3) - extended_coord = torch.gather(system["coord"], dim=1, index=index) - extended_coord = extended_coord - system["shift"] + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.get_sel(), + distinguish_types=self.distinguish_types(), + box=box, + ) env_mat, _, _ = prod_env_mat_se_a( extended_coord, - system["nlist"], - system["atype"], + nlist, + atype, self.mean, self.stddev, self.rcut, @@ -199,15 +224,16 @@ def compute_input_stats(self, merged): ) if not mixed_type: sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"] + env_mat.detach().cpu().numpy(), self.ndescrpt, natoms ) else: + real_natoms_vec = system["real_natoms_vec"] sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( env_mat.detach().cpu().numpy(), self.ndescrpt, - system["real_natoms_vec"], + real_natoms_vec, mixed_type=mixed_type, - real_atype=system["atype"].detach().cpu().numpy(), + real_atype=atype.detach().cpu().numpy(), ) sumr.append(sysr) suma.append(sysa) diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 9191f8c58f..a68ddd45a5 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -14,13 +14,9 @@ fit_output_to_model_output, ) from deepmd.pt.utils.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, + extend_input_and_build_neighbor_list, nlist_distinguish_types, ) -from deepmd.pt.utils.region import ( - normalize_coord, -) def make_model(T_AtomicModel): @@ -97,26 +93,19 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ - nframes, nloc = atype.shape[:2] - if box is not None: - coord_normalized = normalize_coord( - coord.view(nframes, nloc, 3), - box.reshape(nframes, 3, 3), - ) - else: - coord_normalized = coord.clone() - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, self.get_rcut() - ) - nlist = build_neighbor_list( + ( extended_coord, extended_atype, - nloc, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, self.get_rcut(), self.get_sel(), distinguish_types=self.distinguish_types(), + box=box, ) - extended_coord = extended_coord.view(nframes, -1, 3) model_predict_lower = self.forward_common_lower( extended_coord, extended_atype, diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index c8ade925c0..1b3e2c3d65 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -223,7 +223,7 @@ def compute_output_stats(self, merged): return {"bias_atom_e": bias_atom_e} def init_fitting_stat(self, bias_atom_e=None, **kwargs): - assert True not in [x is None for x in [bias_atom_e]] + assert all(x is not None for x in [bias_atom_e]) self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view( [self.ntypes, self.dim_out] diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 0ec43f5a75..7f8cf4eb3c 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -267,26 +267,7 @@ def collate_batch(batch): example = batch[0] result = example.copy() for key in example.keys(): - if key == "shift" or key == "mapping": - natoms_extended = max([d[key].shape[0] for d in batch]) - n_frames = len(batch) - list = [] - for x in range(n_frames): - list.append(batch[x][key]) - if key == "shift": - result[key] = torch.zeros( - (n_frames, natoms_extended, 3), - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) - else: - result[key] = torch.zeros( - (n_frames, natoms_extended), - dtype=torch.long, - ) - for i in range(len(batch)): - natoms_tmp = list[i].shape[0] - result[key][i, :natoms_tmp] = list[i] - elif "find_" in key: + if "find_" in key: result[key] = batch[0][key] else: if batch[0][key] is None: diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index aca4a9ce5b..60055ebda9 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -9,7 +9,6 @@ import h5py import numpy as np import torch -import torch.distributed as dist from torch.utils.data import ( Dataset, ) @@ -189,85 +188,6 @@ def add( "high_prec": high_prec, } - # deprecated TODO - def get_batch_for_train(self, batch_size: int): - """Get a batch of data with at most `batch_size` frames. The frames are randomly picked from the data system. - - Args: - - batch_size: Frame count. - """ - if not hasattr(self, "_frames"): - self.set_size = 0 - self._set_count = 0 - self._iterator = 0 - if batch_size == "auto": - batch_size = -(-32 // self._natoms) - if self._iterator + batch_size > self.set_size: - set_idx = self._set_count % len(self._dirs) - if self.sets[set_idx] is None: - frames = self._load_set(self._dirs[set_idx]) - frames = self.preprocess(frames) - cnt = 0 - for item in self.sets: - if item is not None: - cnt += 1 - if cnt < env.CACHE_PER_SYS: - self.sets[set_idx] = frames - else: - frames = self.sets[set_idx] - self._frames = frames - self._shuffle_data() - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - ssize = self._frames["coord"].shape[0] - subsize = ssize // world_size - self._iterator = rank * subsize - self.set_size = min((rank + 1) * subsize, ssize) - else: - self.set_size = self._frames["coord"].shape[0] - self._iterator = 0 - self._set_count += 1 - iterator = min(self._iterator + batch_size, self.set_size) - idx = np.arange(self._iterator, iterator) - self._iterator += batch_size - return self._get_subdata(idx) - - # deprecated TODO - def get_batch(self, batch_size: int): - """Get a batch of data with at most `batch_size` frames. The frames are randomly picked from the data system. - Args: - - batch_size: Frame count. - """ - if not hasattr(self, "_frames"): - self.set_size = 0 - self._set_count = 0 - self._iterator = 0 - if batch_size == "auto": - batch_size = -(-32 // self._natoms) - if self._iterator + batch_size > self.set_size: - set_idx = self._set_count % len(self._dirs) - if self.sets[set_idx] is None: - frames = self._load_set(self._dirs[set_idx]) - frames = self.preprocess(frames) - cnt = 0 - for item in self.sets: - if item is not None: - cnt += 1 - if cnt < env.CACHE_PER_SYS: - self.sets[set_idx] = frames - else: - frames = self.sets[set_idx] - self._frames = frames - self._shuffle_data() - self.set_size = self._frames["coord"].shape[0] - self._iterator = 0 - self._set_count += 1 - iterator = min(self._iterator + batch_size, self.set_size) - idx = np.arange(self._iterator, iterator) - self._iterator += batch_size - return self._get_subdata(idx) - def get_ntypes(self): """Number of atom types in the system.""" if self._type_map is not None: @@ -470,63 +390,6 @@ def _load_data( data = np.zeros([nframes, ndof]).astype(env.GLOBAL_NP_FLOAT_PRECISION) return np.float32(0.0), data - # deprecated TODO - def preprocess(self, batch): - n_frames = batch["coord"].shape[0] - for kk in self._data_dict.keys(): - if "find_" in kk: - pass - else: - batch[kk] = torch.tensor(batch[kk], dtype=env.GLOBAL_PT_FLOAT_PRECISION) - if self._data_dict[kk]["atomic"]: - batch[kk] = batch[kk].view( - n_frames, -1, self._data_dict[kk]["ndof"] - ) - - for kk in ["type", "real_natoms_vec"]: - if kk in batch.keys(): - batch[kk] = torch.tensor(batch[kk], dtype=torch.long) - batch["atype"] = batch.pop("type") - - keys = ["nlist", "nlist_loc", "nlist_type", "shift", "mapping"] - coord = batch["coord"] - atype = batch["atype"] - box = batch["box"] - rcut = self.rcut - sec = self.sec - assert batch["atype"].max() < len(self._type_map) - nlist, nlist_loc, nlist_type, shift, mapping = [], [], [], [], [] - - for sid in range(n_frames): - region = Region3D(box[sid]) - nloc = atype[sid].shape[0] - _coord = normalize_coord(coord[sid], region, nloc) - coord[sid] = _coord - a, b, c, d, e = make_env_mat( - _coord, atype[sid], region, rcut, sec, type_split=self.type_split - ) - nlist.append(a) - nlist_loc.append(b) - nlist_type.append(c) - shift.append(d) - mapping.append(e) - nlist = torch.stack(nlist) - nlist_loc = torch.stack(nlist_loc) - nlist_type = torch.stack(nlist_type) - batch["nlist"] = nlist - batch["nlist_loc"] = nlist_loc - batch["nlist_type"] = nlist_type - natoms_extended = max([item.shape[0] for item in shift]) - batch["shift"] = torch.zeros( - (n_frames, natoms_extended, 3), dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - batch["mapping"] = torch.zeros((n_frames, natoms_extended), dtype=torch.long) - for i in range(len(shift)): - natoms_tmp = shift[i].shape[0] - batch["shift"][i, :natoms_tmp] = shift[i] - batch["mapping"][i, :natoms_tmp] = mapping[i] - return batch - def _shuffle_data(self): nframes = self._frames["coord"].shape[0] idx = np.arange(nframes) @@ -563,46 +426,21 @@ def single_preprocess(self, batch, sid): for kk in ["type", "real_natoms_vec"]: if kk in batch.keys(): batch[kk] = torch.tensor(batch[kk][sid], dtype=torch.long) - clean_coord = batch.pop("coord") - clean_type = batch.pop("type") - nloc = clean_type.shape[0] + batch["atype"] = batch["type"] rcut = self.rcut sec = self.sec - nlist, nlist_loc, nlist_type, shift, mapping = [], [], [], [], [] - if self.pbc: - box = batch["box"] - region = Region3D(box) - else: - box = None + if not self.pbc: batch["box"] = None - region = None if self.noise_settings is None: - batch["atype"] = clean_type - batch["coord"] = clean_coord - coord = clean_coord - atype = batch["atype"] + return batch + else: # TODO need to clean up this method! if self.pbc: - _coord = normalize_coord(coord, region, nloc) - + region = Region3D(batch["box"]) else: - _coord = coord.clone() - batch["coord"] = _coord - nlist, nlist_loc, nlist_type, shift, mapping = make_env_mat( - _coord, - atype, - region, - rcut, - sec, - pbc=self.pbc, - type_split=self.type_split, - ) - batch["nlist"] = nlist - batch["nlist_loc"] = nlist_loc - batch["nlist_type"] = nlist_type - batch["shift"] = shift - batch["mapping"] = mapping - return batch - else: + region = None + clean_coord = batch.pop("coord") + clean_type = batch.pop("type") + nloc = clean_type.shape[0] batch["clean_type"] = clean_type if self.pbc: _clean_coord = normalize_coord(clean_coord, region, nloc) @@ -678,7 +516,7 @@ def single_preprocess(self, batch, sid): else: _coord = noised_coord.clone() try: - nlist, nlist_loc, nlist_type, shift, mapping = make_env_mat( + _ = make_env_mat( _coord, masked_type, region, @@ -694,13 +532,8 @@ def single_preprocess(self, batch, sid): f"Add noise times beyond max tries {self.max_fail_num}!" ) continue - batch["atype"] = masked_type + batch["type"] = masked_type batch["coord"] = noised_coord - batch["nlist"] = nlist - batch["nlist_loc"] = nlist_loc - batch["nlist_type"] = nlist_type - batch["shift"] = shift - batch["mapping"] = mapping return batch def _get_item(self, index): @@ -783,104 +616,3 @@ def __getitem__(self, index): b_data = self._data_system._get_item(index) b_data["natoms"] = torch.tensor(self._natoms_vec) return b_data - - -# deprecated TODO -class DeepmdDataSet(Dataset): - def __init__( - self, - systems: List[str], - batch_size: int, - type_map: List[str], - rcut=None, - sel=None, - weight=None, - type_split=True, - ): - """Construct DeePMD-style dataset containing frames cross different systems. - - Args: - - systems: Paths to systems. - - batch_size: Max frame count in a batch. - - type_map: Atom types. - """ - self._batch_size = batch_size - self._type_map = type_map - if sel is not None: - if isinstance(sel, int): - sel = [sel] - sec = torch.cumsum(torch.tensor(sel), dim=0) - if isinstance(systems, str): - with h5py.File(systems) as file: - systems = [os.path.join(systems, item) for item in file.keys()] - self._data_systems = [ - DeepmdDataSystem( - ii, rcut, sec, type_map=self._type_map, type_split=type_split - ) - for ii in systems - ] - # check mix_type format - error_format_msg = ( - "if one of the system is of mixed_type format, " - "then all of the systems in this dataset should be of mixed_type format!" - ) - self.mixed_type = self._data_systems[0].mixed_type - for sys_item in self._data_systems[1:]: - assert sys_item.mixed_type == self.mixed_type, error_format_msg - - if weight is None: - - def weight(name, sys): - return sys.nframes - - self.probs = [ - weight(item, self._data_systems[i]) for i, item in enumerate(systems) - ] - self.probs = np.array(self.probs, dtype=float) - self.probs /= self.probs.sum() - self._ntypes = max([ii.get_ntypes() for ii in self._data_systems]) - self._natoms_vec = [ - ii.get_natoms_vec(self._ntypes) for ii in self._data_systems - ] - self.cache = [{} for _ in self._data_systems] - - @property - def nsystems(self): - return len(self._data_systems) - - def __len__(self): - return self.nsystems - - def __getitem__(self, index=None): - """Get a batch of frames from the selected system.""" - if index is None: - index = dp_random.choice(np.arange(self.nsystems), p=self.probs) - b_data = self._data_systems[index].get_batch(self._batch_size) - b_data["natoms"] = torch.tensor(self._natoms_vec[index]) - batch_size = b_data["coord"].shape[0] - b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1) - return b_data - - # deprecated TODO - def get_training_batch(self, index=None): - """Get a batch of frames from the selected system.""" - if index is None: - index = dp_random.choice(np.arange(self.nsystems), p=self.probs) - b_data = self._data_systems[index].get_batch_for_train(self._batch_size) - b_data["natoms"] = torch.tensor(self._natoms_vec[index]) - batch_size = b_data["coord"].shape[0] - b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1) - return b_data - - def get_batch(self, sys_idx=None): - """TF-compatible batch for testing.""" - pt_batch = self[sys_idx] - np_batch = {} - for key in ["coord", "box", "force", "energy", "virial", "atype", "natoms"]: - if key in pt_batch.keys(): - np_batch[key] = pt_batch[key].cpu().numpy() - batch_size = pt_batch["coord"].shape[0] - np_batch["coord"] = np_batch["coord"].reshape(batch_size, -1) - np_batch["natoms"] = np_batch["natoms"][0] - np_batch["force"] = np_batch["force"].reshape(batch_size, -1) - return np_batch, pt_batch diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index fdb2627f04..963c9bc9b6 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -12,10 +12,42 @@ env, ) from deepmd.pt.utils.region import ( + normalize_coord, to_face_distance, ) +def extend_input_and_build_neighbor_list( + coord, + atype, + rcut: float, + sel: List[int], + distinguish_types: bool = False, + box: Optional[torch.Tensor] = None, +): + nframes, nloc = atype.shape[:2] + if box is not None: + coord_normalized = normalize_coord( + coord.view(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) + else: + coord_normalized = coord.clone() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=distinguish_types, + ) + extended_coord = extended_coord.view(nframes, -1, 3) + return extended_coord, extended_atype, mapping, nlist + + def build_neighbor_list( coord1: torch.Tensor, atype: torch.Tensor, diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 932ba9a409..76b2afe41b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -5,10 +5,6 @@ import numpy as np import torch -from deepmd.pt.utils import ( - env, -) - log = logging.getLogger(__name__) @@ -31,11 +27,6 @@ def make_stat_input(datasets, dataloaders, nbatches): "atype", "box", "natoms", - "mapping", - "nlist", - "nlist_loc", - "nlist_type", - "shift", ] if datasets[0].mixed_type: keys.append("real_natoms_vec") @@ -53,25 +44,6 @@ def make_stat_input(datasets, dataloaders, nbatches): if dd in keys: sys_stat[dd].append(stat_data[dd]) for key in keys: - if key == "mapping" or key == "shift": - extend = max(d.shape[1] for d in sys_stat[key]) - for jj in range(len(sys_stat[key])): - l = [] - item = sys_stat[key][jj] - for ii in range(item.shape[0]): - l.append(item[ii]) - n_frames = len(item) - if key == "shift": - shape = torch.zeros( - (n_frames, extend, 3), - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) - else: - shape = torch.zeros((n_frames, extend), dtype=torch.long) - for i in range(len(item)): - natoms_tmp = l[i].shape[0] - shape[i, :natoms_tmp] = l[i] - sys_stat[key][jj] = shape if not isinstance(sys_stat[key][0], list): if sys_stat[key][0] is None: sys_stat[key] = None @@ -133,4 +105,4 @@ def process_stat_path( has_stat_file_path_list = [ os.path.exists(stat_file_path[key]) for key in stat_file_dict ] - return stat_file_path, False not in has_stat_file_path_list + return stat_file_path, all(has_stat_file_path_list) diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index 2dd996349b..a4493b5b51 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -21,13 +21,16 @@ env, ) from deepmd.pt.utils.dataset import ( - DeepmdDataSet, + DeepmdDataSetForLoader, ) from deepmd.pt.utils.env import ( DEVICE, GLOBAL_NP_FLOAT_PRECISION, GLOBAL_PT_FLOAT_PRECISION, ) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) from deepmd.tf.common import ( expand_sys_str, ) @@ -35,6 +38,10 @@ op_module, ) +from .test_embedding_net import ( + get_single_batch, +) + CUR_DIR = os.path.dirname(__file__) @@ -103,10 +110,14 @@ def setUp(self): self.systems = config["training"]["validation_data"]["systems"] if isinstance(self.systems, str): self.systems = expand_sys_str(self.systems) - ds = DeepmdDataSet( - self.systems, self.bsz, model_config["type_map"], self.rcut, self.sel + ds = DeepmdDataSetForLoader( + self.systems[0], + model_config["type_map"], + self.rcut, + self.sel, + type_split=True, ) - self.np_batch, self.pt_batch = ds.get_batch() + self.np_batch, self.pt_batch = get_single_batch(ds) self.sec = np.cumsum(self.sel) self.ntypes = len(self.sel) self.nnei = sum(self.sel) @@ -122,7 +133,7 @@ def test_consistency(self): dtype=GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) - base_d, base_force, nlist = base_se_a( + base_d, base_force, base_nlist = base_se_a( rcut=self.rcut, rcut_smth=self.rcut_smth, sel=self.sel, @@ -132,14 +143,25 @@ def test_consistency(self): ) pt_coord = self.pt_batch["coord"].to(env.DEVICE) + atype = self.pt_batch["atype"].to(env.DEVICE) pt_coord.requires_grad_(True) - index = self.pt_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3).to(env.DEVICE) - extended_coord = torch.gather(pt_coord, dim=1, index=index) - extended_coord = extended_coord - self.pt_batch["shift"].to(env.DEVICE) - my_d, _, _ = prod_env_mat_se_a( - extended_coord.to(DEVICE), - self.pt_batch["nlist"].to(env.DEVICE), + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + pt_coord, self.pt_batch["atype"].to(env.DEVICE), + self.rcut, + self.sel, + distinguish_types=True, + box=self.pt_batch["box"].to(env.DEVICE), + ) + my_d, _, _ = prod_env_mat_se_a( + extended_coord, + nlist, + atype, avg_zero.reshape([-1, self.nnei, 4]).to(DEVICE), std_ones.reshape([-1, self.nnei, 4]).to(DEVICE), self.rcut, @@ -151,16 +173,16 @@ def test_consistency(self): base_force = base_force.reshape(bsz, -1, 3) base_d = base_d.reshape(bsz, -1, self.nnei, 4) my_d = my_d.view(bsz, -1, self.nnei, 4).cpu().detach().numpy() - nlist = nlist.reshape(bsz, -1, self.nnei) + base_nlist = base_nlist.reshape(bsz, -1, self.nnei) - mapping = self.pt_batch["mapping"].cpu() - my_nlist = self.pt_batch["nlist"].view(bsz, -1).cpu() + mapping = mapping.cpu() + my_nlist = nlist.view(bsz, -1).cpu() mask = my_nlist == -1 my_nlist = my_nlist * ~mask my_nlist = torch.gather(mapping, dim=-1, index=my_nlist) my_nlist = my_nlist * ~mask - mask.long() my_nlist = my_nlist.cpu().view(bsz, -1, self.nnei).numpy() - self.assertTrue(np.allclose(nlist, my_nlist)) + self.assertTrue(np.allclose(base_nlist, my_nlist)) self.assertTrue(np.allclose(np.mean(base_d, axis=2), np.mean(my_d, axis=2))) self.assertTrue(np.allclose(np.std(base_d, axis=2), np.std(my_d, axis=2))) # descriptors may be different when there are multiple neighbors in the same distance diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index 21a43803c9..07d4d34449 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -19,11 +19,7 @@ env, ) from deepmd.pt.utils.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, -) -from deepmd.pt.utils.region import ( - normalize_coord, + extend_input_and_build_neighbor_list, ) dtype = torch.float64 @@ -245,20 +241,9 @@ def test_descriptor_block(self): **dparams, ).to(env.DEVICE) des.load_state_dict(torch.load(self.file_model_param)) - rcut = dparams["rcut"] - nsel = dparams["sel"] coord = self.coord atype = self.atype box = self.cell - nf, nloc = coord.shape[:2] - coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, rcut - ) - # single nlist - nlist = build_neighbor_list( - extended_coord, extended_atype, nloc, rcut, nsel, distinguish_types=False - ) # handel type_embedding type_embedding = TypeEmbedNet(ntypes, 8).to(env.DEVICE) type_embedding.load_state_dict(torch.load(self.file_type_embed)) @@ -266,6 +251,19 @@ def test_descriptor_block(self): ## to save model parameters # torch.save(des.state_dict(), 'model_weights.pth') # torch.save(type_embedding.state_dict(), 'model_weights.pth') + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + des.get_rcut(), + des.get_sel(), + distinguish_types=des.distinguish_types(), + box=box, + ) descriptor, env_mat, diff, rot_mat, sw = des( nlist, extended_coord, @@ -307,18 +305,18 @@ def test_descriptor(self): coord = self.coord atype = self.atype box = self.cell - nf, nloc = coord.shape[:2] - coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, des.get_rcut() - ) - nlist = build_neighbor_list( + ( extended_coord, extended_atype, - nloc, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, des.get_rcut(), - des.get_nsel(), - distinguish_types=False, + des.get_sel(), + distinguish_types=des.distinguish_types(), + box=box, ) descriptor, env_mat, diff, rot_mat, sw = des( extended_coord, diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index e614e64c2f..6b80eb89a2 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -19,11 +19,9 @@ env, ) from deepmd.pt.utils.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, -) -from deepmd.pt.utils.region import ( - normalize_coord, + build_multiple_neighbor_list, + extend_input_and_build_neighbor_list, + get_multiple_nlist_key, ) dtype = torch.float64 @@ -114,6 +112,7 @@ def setUp(self): self.file_model_param = Path(CUR_DIR) / "models" / "dpa2.pth" self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pth" + # TODO This test for hybrid descriptor should be removed! def test_descriptor_hyb(self): # torch.manual_seed(0) model_hybrid_dpa2 = self.model_json @@ -129,34 +128,13 @@ def test_descriptor_hyb(self): # type_embd of repformer is removed model_dict.pop("descriptor_list.1.type_embd.embedding.weight") des.load_state_dict(model_dict) - all_rcut = [ii["rcut"] for ii in dlist] - all_nsel = [ii["sel"] for ii in dlist] + all_rcut = sorted([ii["rcut"] for ii in dlist]) + all_nsel = sorted([ii["sel"] for ii in dlist]) rcut_max = max(all_rcut) + sel_max = max(all_nsel) coord = self.coord atype = self.atype box = self.cell - nf, nloc = coord.shape[:2] - coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, rcut_max - ) - ## single nlist - # nlist = build_neighbor_list( - # extended_coord, extended_atype, nloc, - # rcut_max, nsel, distinguish_types=False) - nlist_list = [] - for rcut, sel in zip(all_rcut, all_nsel): - nlist_list.append( - build_neighbor_list( - extended_coord, - extended_atype, - nloc, - rcut, - sel, - distinguish_types=False, - ) - ) - nlist = torch.cat(nlist_list, -1) # handel type_embedding type_embedding = TypeEmbedNet(ntypes, 8).to(env.DEVICE) type_embedding.load_state_dict(torch.load(self.file_type_embed)) @@ -164,6 +142,31 @@ def test_descriptor_hyb(self): ## to save model parameters # torch.save(des.state_dict(), 'model_weights.pth') # torch.save(type_embedding.state_dict(), 'model_weights.pth') + ( + extended_coord, + extended_atype, + mapping, + nlist_max, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + rcut_max, + sel_max, + distinguish_types=des.distinguish_types(), + box=box, + ) + nlist_dict = build_multiple_neighbor_list( + extended_coord, + nlist_max, + all_rcut, + all_nsel, + ) + nlist_list = [] + for ii in des.descriptor_list: + nlist_list.append( + nlist_dict[get_multiple_nlist_key(ii.get_rcut(), ii.get_nsel())] + ) + nlist = torch.cat(nlist_list, -1) descriptor, env_mat, diff, rot_mat, sw = des( nlist, extended_coord, @@ -202,18 +205,18 @@ def test_descriptor(self): coord = self.coord atype = self.atype box = self.cell - nf, nloc = coord.shape[:2] - coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, des.repinit.rcut - ) - nlist = build_neighbor_list( + ( extended_coord, extended_atype, - nloc, - des.repinit.rcut, - des.repinit.sel, - distinguish_types=False, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + des.get_rcut(), + des.get_sel(), + distinguish_types=des.distinguish_types(), + box=box, ) descriptor, env_mat, diff, rot_mat, sw = des( extended_coord, diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index f3f899fbe2..d970c8a542 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -23,6 +23,7 @@ from deepmd.pt.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, + extend_input_and_build_neighbor_list, ) from deepmd.pt.utils.utils import ( to_numpy_array, @@ -433,20 +434,13 @@ def test_self_consistency(self): to_numpy_array(ret0["atom_virial"]), to_numpy_array(ret1["atom_virial"]), ) - - coord_ext, atype_ext, mapping = extend_coord_with_ghosts( + coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( to_torch_tensor(self.coord), to_torch_tensor(self.atype), - to_torch_tensor(self.cell), - self.rcut, - ) - nlist = build_neighbor_list( - coord_ext, - atype_ext, - self.nloc, self.rcut, self.sel, distinguish_types=md0.distinguish_types(), + box=to_torch_tensor(self.cell), ) args = [coord_ext, atype_ext, nlist] ret2 = md0.forward_lower(*args, do_atomic_virial=True) diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 407f4949b5..2621b5d135 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -25,12 +25,15 @@ dp_random, ) from deepmd.pt.utils.dataset import ( - DeepmdDataSet, + DeepmdDataSetForLoader, ) from deepmd.pt.utils.env import ( DEVICE, GLOBAL_NP_FLOAT_PRECISION, ) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) from deepmd.tf.common import ( expand_sys_str, ) @@ -43,6 +46,25 @@ def gen_key(worb, depth, elemid): return (worb, depth, elemid) +def get_single_batch(dataset, index=None): + if index is None: + index = dp_random.choice(np.arange(len(dataset))) + pt_batch = dataset[index] + np_batch = {} + # TODO deprecated + for key in ["mapping", "shift", "nlist"]: + if key in pt_batch.keys(): + pt_batch[key] = pt_batch[key].unsqueeze(0) + for key in ["coord", "box", "force", "energy", "virial", "atype", "natoms"]: + if key in pt_batch.keys(): + pt_batch[key] = pt_batch[key].unsqueeze(0) + np_batch[key] = pt_batch[key].cpu().numpy() + np_batch["coord"] = np_batch["coord"].reshape(1, -1) + np_batch["natoms"] = np_batch["natoms"][0] + np_batch["force"] = np_batch["force"].reshape(1, -1) + return np_batch, pt_batch + + def base_se_a(descriptor, coord, atype, natoms, box): g = tf.Graph() with g.as_default(): @@ -105,12 +127,16 @@ def setUp(self): self.systems = config["training"]["validation_data"]["systems"] if isinstance(self.systems, str): self.systems = expand_sys_str(self.systems) - ds = DeepmdDataSet( - self.systems, self.bsz, model_config["type_map"], self.rcut, self.sel + ds = DeepmdDataSetForLoader( + self.systems[0], + model_config["type_map"], + self.rcut, + self.sel, + type_split=True, ) self.filter_neuron = model_config["descriptor"]["neuron"] self.axis_neuron = model_config["descriptor"]["axis_neuron"] - self.np_batch, self.torch_batch = ds.get_batch() + self.np_batch, self.torch_batch = get_single_batch(ds) def test_consistency(self): dp_d = DescrptSeA_tf( @@ -154,20 +180,23 @@ def test_consistency(self): pt_coord = self.torch_batch["coord"].to(env.DEVICE) pt_coord.requires_grad_(True) - index = ( - self.torch_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3).to(env.DEVICE) - ) - extended_coord = torch.gather(pt_coord, dim=1, index=index) - extended_coord = extended_coord - self.torch_batch["shift"].to(env.DEVICE) - extended_atype = torch.gather( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + pt_coord, self.torch_batch["atype"].to(env.DEVICE), - dim=1, - index=self.torch_batch["mapping"].to(env.DEVICE), + self.rcut, + self.sel, + distinguish_types=True, + box=self.torch_batch["box"].to(env.DEVICE), ) descriptor_out, _, _, _, _ = descriptor( extended_coord, extended_atype, - self.torch_batch["nlist"].to(env.DEVICE), + nlist, ) my_embedding = descriptor_out.cpu().detach().numpy() fake_energy = torch.sum(descriptor_out) diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index 522b30b2df..efe013a8a1 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -331,7 +331,9 @@ def test_consistency(self): # print(dst.mean(), dst.std()) dst.copy_(src) # Start forward computing - batch = my_ds.systems[0]._data_system.preprocess(batch) + batch = my_ds.systems[0]._data_system.single_preprocess(batch, 0) + for key in ["coord", "atype", "box", "energy", "force"]: + batch[key] = batch[key].unsqueeze(0) batch["coord"].requires_grad_(True) batch["natoms"] = torch.tensor( batch["natoms_vec"], device=batch["coord"].device diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 14934c7be0..f0c75ef288 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -16,7 +16,7 @@ EnergyStdLoss, ) from deepmd.pt.utils.dataset import ( - DeepmdDataSet, + DeepmdDataSetForLoader, ) from deepmd.tf.common import ( expand_sys_str, @@ -25,6 +25,10 @@ EnerStdLoss, ) +from .model.test_embedding_net import ( + get_single_batch, +) + CUR_DIR = os.path.dirname(__file__) @@ -39,12 +43,13 @@ def get_batch(): rcut = model_config["descriptor"]["rcut"] # self.rcut_smth = model_config['descriptor']['rcut_smth'] sel = model_config["descriptor"]["sel"] - batch_size = config["training"]["training_data"]["batch_size"] systems = config["training"]["validation_data"]["systems"] if isinstance(systems, str): systems = expand_sys_str(systems) - dataset = DeepmdDataSet(systems, batch_size, model_config["type_map"], rcut, sel) - np_batch, pt_batch = dataset.get_batch() + dataset = DeepmdDataSetForLoader( + systems[0], model_config["type_map"], rcut, sel, type_split=True + ) + np_batch, pt_batch = get_single_batch(dataset) return np_batch, pt_batch diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 240c354a69..bc95575a5a 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -165,10 +165,7 @@ def test_descriptor(self): "energy", "atype", "natoms", - "extended_coord", - "nlist", - "shift", - "mapping", + "box", ]: if key in sys.keys(): sys[key] = sys[key].to(env.DEVICE)