diff --git a/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb b/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb index 1f899118..b1e6175c 100644 --- a/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb +++ b/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb @@ -268,7 +268,8 @@ "source": [ "# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n", "logger = CSVLogger(\"logs\", name=\"M3GNet_training\")\n", - "trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger)\n", + "# Inference mode = False is required for calculating forces, stress in test mode and prediction mode\n", + "trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n", "trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)" ] }, @@ -405,7 +406,7 @@ "source": [ "# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n", "logger = CSVLogger(\"logs\", name=\"M3GNet_finetuning\")\n", - "trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger)\n", + "trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n", "trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)" ] }, @@ -467,7 +468,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py index 4e9c2691..5b923469 100644 --- a/tests/utils/test_training.py +++ b/tests/utils/test_training.py @@ -113,9 +113,10 @@ def test_m3gnet_training(self, LiFePO4, BaNiO3): model = M3GNet(element_types=element_types, is_intensive=False) lit_model = PotentialLightningModule(model=model, stress_weight=0.0001) # We will use CPU if MPS is available since there is a serious bug. - trainer = pl.Trainer(max_epochs=5, accelerator=device) + trainer = pl.Trainer(max_epochs=5, accelerator=device, inference_mode=False) trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader) + trainer.test(lit_model, dataloaders=test_loader) pred_LFP_energy = model.predict_structure(LiFePO4) pred_BNO_energy = model.predict_structure(BaNiO3)