From 37bee250e3167ef2489dcb18ccae319b9bbd3cee Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Fri, 12 Jul 2024 17:56:20 +0800
Subject: [PATCH] fix(pt): add `finetune_head` to argcheck (#3967)

Add `finetune_head` to argcheck.

## Summary by CodeRabbit

- **New Features**
  - Introduced a new `finetune_head` argument for specifying the fitting net during multi-task fine-tuning, with random initialization if not set.

- **Bug Fixes**
  - Improved handling for specific conditions by automatically removing the `finetune_head` key from the configuration.

- **Tests**
  - Updated the multitask training and finetuning tests to run the new configuration normalization steps.
  - Removed a stray `_comment` field from the test configuration file for a cleaner test setup.

---
 deepmd/utils/argcheck.py                   | 10 ++++++++++
 source/tests/pt/model/water/multitask.json |  1 -
 source/tests/pt/test_multitask.py          | 10 ++++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 0bf50fd019..acb243ea2f 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1533,6 +1533,10 @@ def model_args(exclude_hybrid=False):
     doc_spin = "The settings for systems with spin."
     doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types"
     doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other."
+    doc_finetune_head = (
+        "The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. "
+        "If not set or set to 'RANDOM', the fitting net will be randomly initialized."
+    )
 
     hybrid_models = []
     if not exclude_hybrid:
@@ -1629,6 +1633,12 @@ def model_args(exclude_hybrid=False):
                 fold_subdoc=True,
             ),
             Argument("spin", dict, spin_args(), [], optional=True, doc=doc_spin),
+            Argument(
+                "finetune_head",
+                str,
+                optional=True,
+                doc=doc_only_pt_supported + doc_finetune_head,
+            ),
         ],
         [
             Variant(
diff --git a/source/tests/pt/model/water/multitask.json b/source/tests/pt/model/water/multitask.json
index c59618145d..06a4f88e55 100644
--- a/source/tests/pt/model/water/multitask.json
+++ b/source/tests/pt/model/water/multitask.json
@@ -68,7 +68,6 @@
     "_comment": "that's all"
   },
   "loss_dict": {
-    "_comment": " that's all",
     "model_1": {
       "type": "ener",
       "start_pref_e": 0.02,
diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py
index cf9ec9685d..d0a647dbbf 100644
--- a/source/tests/pt/test_multitask.py
+++ b/source/tests/pt/test_multitask.py
@@ -21,6 +21,12 @@
 from deepmd.pt.utils.multi_task import (
     preprocess_shared_params,
 )
+from deepmd.utils.argcheck import (
+    normalize,
+)
+from deepmd.utils.compat import (
+    update_deepmd_input,
+)
 
 from .model.test_permutation import (
     model_dpa1,
@@ -39,6 +45,8 @@ def setUpModule():
 class MultiTaskTrainTest:
     def test_multitask_train(self):
         # test multitask training
+        self.config = update_deepmd_input(self.config, warning=True)
+        self.config = normalize(self.config, multi_task=True)
         trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
         trainer.run()
         # check model keys
@@ -124,6 +132,8 @@ def test_multitask_train(self):
             finetune_model,
             self.origin_config["model"],
         )
+        self.origin_config = update_deepmd_input(self.origin_config, warning=True)
+        self.origin_config = normalize(self.origin_config, multi_task=True)
         trainer_finetune = get_trainer(
             deepcopy(self.origin_config),
             finetune_model=finetune_model,
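
Usage note (not part of the patch): a minimal sketch of where the new key sits in a fine-tuning input script. The branch name `model_1`, the elided descriptor/fitting settings, and the checkpoint filename below are illustrative assumptions, not taken from this change:

```json
{
    "model": {
        "descriptor": {
            "_comment": "descriptor settings elided; illustrative only"
        },
        "fitting_net": {
            "_comment": "fitting settings elided; illustrative only"
        },
        "finetune_head": "model_1"
    }
}
```

With such a script, fine-tuning from a multi-task checkpoint (e.g. `dp --pt train input.json --finetune multitask.ckpt.pt`) would start from the pretrained `model_1` fitting branch; leaving `finetune_head` unset or setting it to `'RANDOM'` falls back to a randomly initialized fitting net, per the doc string added in `argcheck.py`.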