Feat/add dos finetune UT #3876

Merged (10 commits) on Jun 17, 2024
17 changes: 7 additions & 10 deletions deepmd/pt/train/training.py
@@ -31,7 +31,6 @@
TensorLoss,
)
from deepmd.pt.model.model import (
-DOSModel,
get_model,
get_zbl_model,
)
@@ -601,15 +600,13 @@
_finetune_rule_single,
_sample_func,
):
-# need fix for DOSModel
-if not isinstance(_model, DOSModel):
-    _model = _model_change_out_bias(
-        _model,
-        _sample_func,
-        _bias_adjust_mode="change-by-statistic"
-        if not _finetune_rule_single.get_random_fitting()
-        else "set-by-statistic",
-    )
+_model = _model_change_out_bias(
+    _model,
+    _sample_func,
+    _bias_adjust_mode="change-by-statistic"
+    if not _finetune_rule_single.get_random_fitting()
+    else "set-by-statistic",
+)
return _model

if not self.multi_task:
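With this change the DOS special case is gone: every fine-tuned model is passed through _model_change_out_bias, and only the bias-adjust mode depends on the fine-tune rule. A minimal sketch of that selection logic, with a hypothetical helper name used purely for illustration:

def select_bias_adjust_mode(random_fitting: bool) -> str:
    # A randomly re-initialized fitting net has no meaningful bias to correct,
    # so the bias is set directly from data statistics; otherwise the existing
    # bias is shifted by the statistic difference.
    return "set-by-statistic" if random_fitting else "change-by-statistic"

# e.g. mode = select_bias_adjust_mode(_finetune_rule_single.get_random_fitting())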
4 changes: 4 additions & 0 deletions deepmd/pt/utils/stat.py
@@ -399,6 +399,10 @@
model_pred: Optional[Dict[str, np.ndarray]] = None,
):
"""This function only handle stat computation from reduced global labels."""
+# return directly if model predict is empty for global
+if not model_pred:
+    return {}, {}

# get label dict from sample; for each key, only picking the system with global labels.
outputs = {
kk: [
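The added guard lets the global-label statistics step return early when a model supplies no global predictions (as can happen on the DOS fine-tuning path): both None and an empty dict are falsy, so two empty dicts come back instead of iterating over nothing. A self-contained sketch of that behaviour, with a hypothetical function name standing in for the real one:

from typing import Dict, Optional, Tuple

import numpy as np

def global_stats_with_guard(
    model_pred: Optional[Dict[str, np.ndarray]] = None,
) -> Tuple[dict, dict]:
    # Mirrors the early return added above; the real statistics
    # computation would continue past this check.
    if not model_pred:
        return {}, {}
    raise NotImplementedError

assert global_stats_with_guard(None) == ({}, {})
assert global_stats_with_guard({}) == ({}, {})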
36 changes: 32 additions & 4 deletions source/tests/pt/test_finetune.py
@@ -44,6 +44,7 @@
)

from .model.test_permutation import (
+model_dos,
model_dpa1,
model_dpa2,
model_se_e2_a,
@@ -72,6 +73,13 @@
must=False,
high_prec=False,
),
+DataRequirementItem(
+    "dos",
+    ndof=250,
+    atomic=False,
+    must=False,
+    high_prec=True,
+),
DataRequirementItem(
"atom_ener",
ndof=1,
@@ -92,6 +100,7 @@

class FinetuneTest:
def test_finetune_change_out_bias(self):
+self.testkey = "energy" if self.testkey is None else self.testkey
# get data
data = DpLoaderSet(
self.data_file,
@@ -108,7 +117,7 @@ def test_finetune_change_out_bias(self):
model = get_model(self.config["model"]).to(env.DEVICE)
atomic_model = model.atomic_model
atomic_model["out_bias"] = torch.rand_like(atomic_model["out_bias"])
-energy_bias_before = to_numpy_array(atomic_model["out_bias"])[0].ravel()
+energy_bias_before = to_numpy_array(atomic_model["out_bias"])[0]

# prepare original model for test
dp = torch.jit.script(model)
@@ -123,7 +132,7 @@
sampled,
bias_adjust_mode="change-by-statistic",
)
-energy_bias_after = to_numpy_array(atomic_model["out_bias"])[0].ravel()
+energy_bias_after = to_numpy_array(atomic_model["out_bias"])[0]

# get ground-truth energy bias change
sorter = np.argsort(full_type_map)
@@ -140,10 +149,10 @@
to_numpy_array(sampled[0]["box"][:ntest]),
to_numpy_array(sampled[0]["atype"][0]),
)[0]
-energy_diff = to_numpy_array(sampled[0]["energy"][:ntest]) - energy
+energy_diff = to_numpy_array(sampled[0][self.testkey][:ntest]) - energy
finetune_shift = (
energy_bias_after[idx_type_map] - energy_bias_before[idx_type_map]
-)
+).ravel()
ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[
0
].reshape(-1)
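The ground-truth shift is the least-squares solution of atom_nums @ shift = label_diff, where the global label is now selected via self.testkey ("energy" by default, "dos" for the DOS model). A toy check of that relation with made-up numbers:

import numpy as np

# hypothetical example: 3 frames, 2 atom types, true per-type shift (0.5, -0.2)
atom_nums = np.array([[4.0, 2.0], [3.0, 3.0], [5.0, 1.0]])  # atom count per type, per frame
true_shift = np.array([0.5, -0.2])
label_diff = atom_nums @ true_shift  # change of the global label for each frame

recovered = np.linalg.lstsq(atom_nums, label_diff, rcond=None)[0]
np.testing.assert_allclose(recovered, true_shift)

For the DOS model the global label has ndof=250, so the label difference is a matrix and the solve returns one shift per grid point; that is why finetune_shift is flattened with .ravel() before being compared to ground_truth_shift.reshape(-1).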
@@ -262,6 +271,7 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.mixed_types = False
+self.testkey = None


class TestEnergyZBLModelSeA(FinetuneTest, unittest.TestCase):
@@ -276,6 +286,22 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.mixed_types = False
+self.testkey = None


+class TestEnergyDOSModelSeA(FinetuneTest, unittest.TestCase):
+    def setUp(self):
+        input_json = str(Path(__file__).parent / "dos/input.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        self.data_file = [str(Path(__file__).parent / "dos/data/global_system")]
+        self.config["training"]["training_data"]["systems"] = self.data_file
+        self.config["training"]["validation_data"]["systems"] = self.data_file
+        self.config["model"] = deepcopy(model_dos)
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+        self.mixed_types = False
+        self.testkey = "dos"
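Since the new case only overrides setUp, the shared test body in FinetuneTest exercises the DOS model unchanged. It can be run on its own with pytest in the usual way (illustrative command, from the repository root):

python -m pytest source/tests/pt/test_finetune.py::TestEnergyDOSModelSeA -v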


class TestEnergyModelDPA1(FinetuneTest, unittest.TestCase):
@@ -290,6 +316,7 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.mixed_types = True
+self.testkey = None


class TestEnergyModelDPA2(FinetuneTest, unittest.TestCase):
@@ -306,6 +333,7 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.mixed_types = True
+self.testkey = None


if __name__ == "__main__":