diff --git a/deepmd/main.py b/deepmd/main.py index 4d2d62ed14..df5c99bb2d 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -226,7 +226,7 @@ def main_parser() -> argparse.ArgumentParser: "--init-frz-model", type=str, default=None, - help="(Supported backend: TensorFlow) Initialize the training from the frozen model.", + help="Initialize the training from the frozen model.", ) parser_train_subgroup.add_argument( "-t", diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 212a6824e7..a317cea6a9 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -65,6 +65,7 @@ def get_trainer( finetune_model=None, model_branch="", force_load=False, + init_frz_model=None, ): # Initialize DDP local_rank = os.environ.get("LOCAL_RANK") @@ -200,6 +201,7 @@ def prepare_trainer_input_single( finetune_model=finetune_model, force_load=force_load, shared_links=shared_links, + init_frz_model=init_frz_model, ) return trainer @@ -243,6 +245,7 @@ def train(FLAGS): FLAGS.finetune, FLAGS.model_branch, FLAGS.force_load, + FLAGS.init_frz_model, ) trainer.run() diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 5a783e412b..152c69a444 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -75,6 +75,7 @@ def __init__( finetune_model=None, force_load=False, shared_links=None, + init_frz_model=None, ): """Construct a DeePMD trainer. @@ -271,7 +272,7 @@ def get_loss(loss_params, start_lr, _ntypes): self.warmup_steps = training_params.get("warmup_steps", 0) self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) assert ( - self.num_steps - self.warmup_steps > 0 + self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0 ), "Warm up steps must be less than total training steps!" if self.multi_task and config.get("learning_rate_dict", None) is not None: self.lr_exp = {} @@ -394,6 +395,9 @@ def get_loss(loss_params, start_lr, _ntypes): ntest=ntest, bias_shift=model_params.get("bias_shift", "delta"), ) + if init_frz_model is not None: + frz_model = torch.jit.load(init_frz_model, map_location=DEVICE) + self.model.load_state_dict(frz_model.state_dict()) # Set trainable params self.wrapper.set_trainable_params() @@ -724,6 +728,15 @@ def log_loss_valid(_task_key="Default"): if ( self.rank == 0 or dist.get_rank() == 0 ): # Handle the case if rank 0 aborted and re-assigned + if self.num_steps == 0: + # when num_steps is 0, the checkpoint is never not saved + self.latest_model = Path(self.save_ckpt + "-0.pt") + self.save_model(self.latest_model, lr=0, step=0) + log.info(f"Saved model to {self.latest_model}") + symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + with open("checkpoint", "w") as f: + f.write(str(self.latest_model)) + if JIT: pth_model_path = ( "frozen_model.pth" # We use .pth to denote the frozen model @@ -759,9 +772,10 @@ def get_data(self, is_train=True, task_key="Default"): batch_data = next(iter(self.training_data)) except StopIteration: # Refresh the status of the dataloader to start from a new epoch - self.training_data = BufferedIterator( - iter(self.training_dataloader) - ) + with torch.device("cpu"): + self.training_data = BufferedIterator( + iter(self.training_dataloader) + ) batch_data = next(iter(self.training_data)) else: try: diff --git a/source/tests/pt/test_init_frz_model.py b/source/tests/pt/test_init_frz_model.py new file mode 100644 index 0000000000..d156eddc41 --- /dev/null +++ b/source/tests/pt/test_init_frz_model.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import unittest +from argparse import ( + Namespace, +) +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.pt.entrypoints.main import ( + freeze, + get_trainer, +) +from deepmd.pt.infer.deep_eval import ( + DeepPot, +) + + +class TestInitFrzModel(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + config = json.load(f) + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + config["learning_rate"]["start_lr"] = 1.0 + config["training"]["training_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + + self.models = [] + for imodel in range(2): + if imodel == 1: + config["training"]["numb_steps"] = 0 + trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1]) + else: + trainer = get_trainer(deepcopy(config)) + trainer.run() + + frozen_model = f"frozen_model{imodel}.pth" + ns = Namespace( + model="model.pt", + output=frozen_model, + head=None, + ) + freeze(ns) + self.models.append(frozen_model) + + def test_dp_test(self): + dp1 = DeepPot(str(self.models[0])) + dp2 = DeepPot(str(self.models[1])) + cell = np.array( + [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + ).reshape(1, 3, 3) + coord = np.array( + [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + ).reshape(1, -1, 3) + atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) + + e1, f1, v1, ae1, av1 = dp1.eval(coord, cell, atype, atomic=True) + e2, f2, v2, ae2, av2 = dp2.eval(coord, cell, atype, atomic=True) + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)