fix(pt): add finetune_head to argcheck (deepmodeling#3967)
Add `finetune_head` to argcheck.

## Summary by CodeRabbit

- **New Features**
  - Introduced a new `finetune_head` argument for specifying the fitting net during multi-task fine-tuning; the fitting net is randomly initialized if the argument is not set (see the config sketch below).

- **Bug Fixes**
  - Improved handling of specific conditions by automatically removing the `finetune_head` key from the configuration.

- **Tests**
  - Updated multitask training and fine-tuning tests to include the new configuration manipulations.
  - Removed the `_comment` field from test configuration files for cleaner test setups.
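For orientation, here is a minimal sketch of where the new key lives inside a model section of a multi-task input. Only `finetune_head` and the `"RANDOM"` sentinel come from this commit; every other key and value is an assumed placeholder, not a complete DeePMD-kit input:

```python
# Illustrative fragment of one model section in a multi-task input.
# Everything except "finetune_head" is an assumed placeholder.
model_section = {
    "type_map": ["O", "H"],  # assumed example types
    "descriptor": {},        # descriptor settings elided
    "fitting_net": {},       # fitting-net settings elided
    # New in this commit: name the pretrained fitting net to start from.
    # Omitting the key, or setting it to "RANDOM", gives a randomly
    # initialized fitting net instead.
    "finetune_head": "model_1",
}
```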
iProzd authored Jul 12, 2024
1 parent 698ff6f commit 37bee25
Showing 3 changed files with 20 additions and 1 deletion.
10 changes: 10 additions & 0 deletions deepmd/utils/argcheck.py
@@ -1533,6 +1533,10 @@ def model_args(exclude_hybrid=False):
doc_spin = "The settings for systems with spin."
doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types"
doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other."
doc_finetune_head = (
"The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. "
"If not set or set to 'RANDOM', the fitting net will be randomly initialized."
)

hybrid_models = []
if not exclude_hybrid:
@@ -1629,6 +1633,12 @@ def model_args(exclude_hybrid=False):
fold_subdoc=True,
),
Argument("spin", dict, spin_args(), [], optional=True, doc=doc_spin),
Argument(
"finetune_head",
str,
optional=True,
doc=doc_only_pt_supported + doc_finetune_head,
),
],
[
Variant(
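As a hedged aside on what registering the argument buys: `dargs` can now recognize and type-check the key during input validation. A minimal sketch mirroring the `Argument` added above (doc string omitted); `check_value` is the standard dargs check on a single value, and `"model_1"` is an assumed head name:

```python
from dargs import Argument

# Mirrors the Argument registered in argcheck.py above (doc omitted).
finetune_head_arg = Argument("finetune_head", str, optional=True)

# Passes: the value is a string naming a fitting head ("model_1" is an
# assumed example). A non-string value would raise a type error instead.
finetune_head_arg.check_value("model_1")
```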
1 change: 0 additions & 1 deletion source/tests/pt/model/water/multitask.json
@@ -68,7 +68,6 @@
"_comment": "that's all"
},
"loss_dict": {
"_comment": " that's all",
"model_1": {
"type": "ener",
"start_pref_e": 0.02,
10 changes: 10 additions & 0 deletions source/tests/pt/test_multitask.py
@@ -21,6 +21,12 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)

from .model.test_permutation import (
model_dpa1,
@@ -39,6 +45,8 @@ def setUpModule():
class MultiTaskTrainTest:
def test_multitask_train(self):
# test multitask training
self.config = update_deepmd_input(self.config, warning=True)
self.config = normalize(self.config, multi_task=True)
trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
trainer.run()
# check model keys
@@ -124,6 +132,8 @@ def test_multitask_train(self):
finetune_model,
self.origin_config["model"],
)
self.origin_config = update_deepmd_input(self.origin_config, warning=True)
self.origin_config = normalize(self.origin_config, multi_task=True)
trainer_finetune = get_trainer(
deepcopy(self.origin_config),
finetune_model=finetune_model,
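Putting the test change in context: the new lines funnel the raw config through compat migration and argcheck validation before the trainer sees it. A minimal sketch of that flow, under the assumptions noted in the comments (the input path is hypothetical; both functions and their keyword arguments appear in the diff above):

```python
import json

from deepmd.utils.argcheck import normalize
from deepmd.utils.compat import update_deepmd_input

# Hypothetical file name; substitute a real multi-task input.
with open("input_multitask.json") as f:
    config = json.load(f)

# Migrate legacy keys first, then validate against the registered
# arguments. With "finetune_head" now declared in argcheck, a multi-task
# input that sets it passes normalize() instead of being rejected.
config = update_deepmd_input(config, warning=True)
config = normalize(config, multi_task=True)
```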
