breaking: pt: remove data stat from model init (deepmodeling#3245)
Restores deepmodeling#3233 with conflicts and review conversations resolved.

This PR cleans up the data stat process and removes it from model initialization.

Please note that this PR is only an initial cleanup and refinement of the
data stat; a more detailed design of the data stat will be completed in the
next PR:

- data stat independent of the dataloader
- data stat support for hybrid descriptors
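For orientation before the diffs: after this change a descriptor owns its own statistics through the new `compute_or_load_stat()` method (added in `deepmd/pt/model/descriptor/descriptor.py` below), instead of the trainer pushing precomputed statistics into the model at init time. The snippet below is a hypothetical wrapper written for illustration only, not code from this commit; the helper name, the `descriptor`/`sample_frames` arguments, and the decide-by-file-existence policy are assumptions (in the actual PR the trainer makes this decision through `process_stat_path` and `has_stat_file_path`).

```python
import os
from typing import Callable, List, Optional

def prepare_descriptor_stat(
    descriptor,                            # a deepmd.pt Descriptor instance (placeholder)
    type_map: List[str],                   # e.g. ["O", "H"]
    sample_frames: Callable[[], list],     # zero-arg callable that gathers sampled frames
    stat_file_path: Optional[str] = None,  # where the statistics are (or will be) cached
) -> None:
    """Compute the descriptor statistics, or restore them from a cached stat file."""
    if stat_file_path is not None and os.path.exists(stat_file_path):
        # A usable stat file already exists: skip sampling and load the cached statistics.
        descriptor.compute_or_load_stat(type_map, sampled=None, stat_file_path=stat_file_path)
    else:
        # No cached statistics yet: gather frames, compute them, and cache if a path is given.
        descriptor.compute_or_load_stat(
            type_map, sampled=sample_frames(), stat_file_path=stat_file_path
        )
```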

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and njzjz committed Feb 10, 2024
1 parent 7802def commit 9f170e0
Showing 41 changed files with 725 additions and 526 deletions.
8 changes: 3 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -69,15 +69,13 @@ def distinguish_types(self) -> bool:
"""
pass

@abstractmethod
def compute_input_stats(self, merged):
"""Update mean and stddev for descriptor elements."""
pass
raise NotImplementedError

@abstractmethod
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
def init_desc_stat(self, **kwargs):
"""Initialize the model bias by the statistics."""
pass
raise NotImplementedError

@abstractmethod
def fwd(
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
@@ -236,6 +236,14 @@ def __getitem__(self, key):
else:
raise KeyError(key)

def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
raise NotImplementedError

def init_fitting_stat(self, result_dict):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/make_base_fitting.py
@@ -52,6 +52,14 @@ def fwd(
"""Calculate fitting."""
pass

def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
raise NotImplementedError

def init_fitting_stat(self, **kwargs):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

@abstractmethod
def serialize(self) -> dict:
"""Serialize the obj to dict."""
75 changes: 29 additions & 46 deletions deepmd/pt/entrypoints/main.py
@@ -46,6 +46,9 @@
 from deepmd.pt.model.descriptor import (
     Descriptor,
 )
+from deepmd.pt.model.task import (
+    Fitting,
+)
 from deepmd.pt.train import (
     training,
 )
@@ -63,6 +66,7 @@
 )
 from deepmd.pt.utils.stat import (
     make_stat_input,
+    process_stat_path,
 )
 from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

@@ -128,51 +132,18 @@ def prepare_trainer_input_single(

         # stat files
         hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
-        has_stat_file_path = True
         if not hybrid_descrpt:
-            ### this design requires "rcut", "rcut_smth" and "sel" in the descriptor
-            ### VERY BAD DESIGN!!!!
-            ### not all descriptors provides these parameter in their constructor
-            default_stat_file_name = Descriptor.get_stat_name(
-                model_params_single["descriptor"]
-            )
-            model_params_single["stat_file_dir"] = data_dict_single.get(
-                "stat_file_dir", f"stat_files{suffix}"
-            )
-            model_params_single["stat_file"] = data_dict_single.get(
-                "stat_file", default_stat_file_name
-            )
-            model_params_single["stat_file_path"] = os.path.join(
-                model_params_single["stat_file_dir"], model_params_single["stat_file"]
-            )
-            if not os.path.exists(model_params_single["stat_file_path"]):
-                has_stat_file_path = False
-        else: ### need to remove this
-            default_stat_file_name = []
-            for descrpt in model_params_single["descriptor"]["list"]:
-                default_stat_file_name.append(
-                    f'stat_file_rcut{descrpt["rcut"]:.2f}_'
-                    f'smth{descrpt["rcut_smth"]:.2f}_'
-                    f'sel{descrpt["sel"]}_{descrpt["type"]}.npz'
-                )
-            model_params_single["stat_file_dir"] = data_dict_single.get(
-                "stat_file_dir", f"stat_files{suffix}"
+            stat_file_path_single, has_stat_file_path = process_stat_path(
+                data_dict_single.get("stat_file", None),
+                data_dict_single.get("stat_file_dir", f"stat_files{suffix}"),
+                model_params_single,
+                Descriptor,
+                Fitting,
             )
-            model_params_single["stat_file"] = data_dict_single.get(
-                "stat_file", default_stat_file_name
+        else: ### TODO hybrid descriptor not implemented
+            raise NotImplementedError(
+                "data stat for hybrid descriptor is not implemented!"
             )
-            assert isinstance(
-                model_params_single["stat_file"], list
-            ), "Stat file of hybrid descriptor must be a list!"
-            stat_file_path = []
-            for stat_file_path_item in model_params_single["stat_file"]:
-                single_file_path = os.path.join(
-                    model_params_single["stat_file_dir"], stat_file_path_item
-                )
-                stat_file_path.append(single_file_path)
-                if not os.path.exists(single_file_path):
-                    has_stat_file_path = False
-            model_params_single["stat_file_path"] = stat_file_path

         # validation and training data
         validation_data_single = DpLoaderSet(
@@ -212,19 +183,30 @@ def prepare_trainer_input_single(
             type_split=type_split,
             noise_settings=noise_settings,
         )
-        return train_data_single, validation_data_single, sampled_single
+        return (
+            train_data_single,
+            validation_data_single,
+            sampled_single,
+            stat_file_path_single,
+        )

     if not multi_task:
-        train_data, validation_data, sampled = prepare_trainer_input_single(
+        (
+            train_data,
+            validation_data,
+            sampled,
+            stat_file_path,
+        ) = prepare_trainer_input_single(
             config["model"], config["training"], config["loss"]
         )
     else:
-        train_data, validation_data, sampled = {}, {}, {}
+        train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
         for model_key in config["model"]["model_dict"]:
             (
                 train_data[model_key],
                 validation_data[model_key],
                 sampled[model_key],
+                stat_file_path[model_key],
             ) = prepare_trainer_input_single(
                 config["model"]["model_dict"][model_key],
                 config["training"]["data_dict"][model_key],
@@ -235,7 +217,8 @@ def prepare_trainer_input_single(
     trainer = training.Trainer(
         config,
         train_data,
-        sampled,
+        sampled=sampled,
+        stat_file_path=stat_file_path,
         validation_data=validation_data,
         init_model=init_model,
         restart_model=restart_model,
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
@@ -99,7 +99,7 @@ def __init__(
self.input_param["resuming"] = True
self.multi_task = "model_dict" in self.input_param
assert not self.multi_task, "multitask mode currently not supported!"
model = get_model(self.input_param, None).to(DEVICE)
model = get_model(self.input_param).to(DEVICE)
model = torch.jit.script(model)
self.dp = ModelWrapper(model)
self.dp.load_state_dict(state_dict)
150 changes: 134 additions & 16 deletions deepmd/pt/model/descriptor/descriptor.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
ABC,
abstractmethod,
@@ -7,6 +8,7 @@
Callable,
List,
Optional,
Union,
)

import numpy as np
@@ -23,6 +25,8 @@
BaseDescriptor,
)

log = logging.getLogger(__name__)


class Descriptor(torch.nn.Module, BaseDescriptor):
"""The descriptor.
@@ -56,15 +60,130 @@ class SomeDescript(Descriptor):
return Descriptor.__plugins.register(key)

@classmethod
def get_stat_name(cls, config):
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
def get_stat_name(cls, ntypes, type_name, **kwargs):
"""
Get the name for the statistic file of the descriptor.
Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
"""
if cls is not Descriptor:
raise NotImplementedError("get_stat_name is not implemented!")
descrpt_type = type_name
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(
ntypes, type_name, **kwargs
)

@classmethod
def get_data_process_key(cls, config):
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
if cls is not Descriptor:
raise NotImplementedError("get_data_process_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)

@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the descriptor.
Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
"""
raise NotImplementedError("data_stat_key is not implemented!")

def compute_or_load_stat(
self,
type_map: List[str],
sampled=None,
stat_file_path: Optional[Union[str, List[str]]] = None,
):
"""
Compute or load the statistics parameters of the descriptor.
Calculate and save the mean and standard deviation of the descriptor to `stat_file_path`
if `sampled` is not None, otherwise load them from `stat_file_path`.
Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
"""
# TODO support hybrid descriptor
descrpt_stat_key = self.data_stat_key
if sampled is not None: # compute the statistics results
tmp_dict = self.compute_input_stats(sampled)
result_dict = {key: tmp_dict[key] for key in descrpt_stat_key}
result_dict["type_map"] = type_map
if stat_file_path is not None:
self.save_stats(result_dict, stat_file_path)
else: # load the statistics results
assert stat_file_path is not None, "No stat file to load!"
result_dict = self.load_stats(type_map, stat_file_path)
self.init_desc_stat(**result_dict)

def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]):
"""
Save the statistics results to `stat_file_path`.
Parameters
----------
result_dict
The dictionary of statistics results.
stat_file_path
The path to the statistics file(s).
"""
if not isinstance(stat_file_path, list):
log.info(f"Saving stat file to {stat_file_path}")
np.savez_compressed(stat_file_path, **result_dict)
else: # TODO hybrid descriptor not implemented
raise NotImplementedError(
"save_stats for hybrid descriptor is not implemented!"
)

def load_stats(self, type_map, stat_file_path: Union[str, List[str]]):
"""
Load the statistics results to `stat_file_path`.
Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
stat_file_path
The path to the statistics file(s).
Returns
-------
result_dict
The dictionary of statistics results.
"""
descrpt_stat_key = self.data_stat_key
target_type_map = type_map
if not isinstance(stat_file_path, list):
log.info(f"Loading stat file from {stat_file_path}")
stats = np.load(stat_file_path)
stat_type_map = list(stats["type_map"])
missing_type = [i for i in target_type_map if i not in stat_type_map]
assert not missing_type, (
f"These type are not in stat file {stat_file_path}: {missing_type}! "
f"Please change the stat file path!"
)
idx_map = [stat_type_map.index(i) for i in target_type_map]
if stats[descrpt_stat_key[0]].size: # not empty
result_dict = {key: stats[key][idx_map] for key in descrpt_stat_key}
else:
result_dict = {key: [] for key in descrpt_stat_key}
else: # TODO hybrid descriptor not implemented
raise NotImplementedError(
"load_stats for hybrid descriptor is not implemented!"
)
return result_dict

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
@@ -156,15 +275,13 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension."""
pass

@abstractmethod
def compute_input_stats(self, merged):
"""Update mean and stddev for DescriptorBlock elements."""
pass
raise NotImplementedError

@abstractmethod
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
"""Initialize the model bias by the statistics."""
pass
def init_desc_stat(self, **kwargs):
"""Initialize mean and stddev by the statistics."""
raise NotImplementedError

def share_params(self, base_class, shared_level, resume=False):
assert (
Expand All @@ -188,13 +305,14 @@ def share_params(self, base_class, shared_level, resume=False):
self.sumr2,
self.suma2,
)
base_class.init_desc_stat(
sumr_base + sumr,
suma_base + suma,
sumn_base + sumn,
sumr2_base + sumr2,
suma2_base + suma2,
)
stat_dict = {
"sumr": sumr_base + sumr,
"suma": suma_base + suma,
"sumn": sumn_base + sumn,
"sumr2": sumr2_base + sumr2,
"suma2": suma2_base + suma2,
}
base_class.init_desc_stat(**stat_dict)
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
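
For reference (an editorial note, not part of this commit): the stat file handled by `save_stats`/`load_stats` above is a plain NumPy `.npz` archive keyed by the descriptor's statistic names plus the `type_map` the statistics were computed for. The following is a minimal, runnable sketch of that round trip and of the type-map reindexing done in `load_stats`, assuming the `sumr`/`suma`/`sumn`/`sumr2`/`suma2` keys seen in `share_params` above; the values and the file name are made up.

```python
import numpy as np

# Illustrative statistics: one entry per atom type, plus the type map they belong to.
result_dict = {
    "sumr": np.array([1.0, 2.0, 3.0]),
    "suma": np.array([0.1, 0.2, 0.3]),
    "sumn": np.array([10.0, 20.0, 30.0]),
    "sumr2": np.array([1.5, 2.5, 3.5]),
    "suma2": np.array([0.2, 0.4, 0.6]),
    "type_map": ["O", "H", "C"],
}
np.savez_compressed("stat_file_example.npz", **result_dict)  # what save_stats does

# Loading for a model whose type map is a re-ordered subset of the stored one,
# mirroring the idx_map reindexing in Descriptor.load_stats:
target_type_map = ["H", "O"]
stats = np.load("stat_file_example.npz")
stat_type_map = list(stats["type_map"])
missing = [t for t in target_type_map if t not in stat_type_map]
assert not missing, f"types missing from the stat file: {missing}"
idx_map = [stat_type_map.index(t) for t in target_type_map]
loaded = {k: stats[k][idx_map] for k in ("sumr", "suma", "sumn", "sumr2", "suma2")}
print(loaded["sumr"])  # -> [2. 1.], the H and O rows in the requested order
```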