diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index 53b497e32d..8c47b0ab77 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -969,7 +969,7 @@ def check_traj_files(batch, traj_dir) -> bool: if traj_dir is None: return False traj_dir = Path(traj_dir) - traj_files = [traj_dir / f"{id}.traj" for id in batch[0].sid.tolist()] + traj_files = [traj_dir / f"{id}.traj" for id in batch.sid.tolist()] return all(fl.exists() for fl in traj_files) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 1ef82baf52..86bab1a9d7 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -316,6 +316,7 @@ def _compute_loss(self, out, batch): target = batch[target_name] pred = out[target_name] + natoms = batch.natoms natoms = torch.repeat_interleave(natoms, natoms) @@ -580,7 +581,7 @@ def run_relaxations(self, split="val"): if check_traj_files( batch, self.config["task"]["relax_opt"].get("traj_dir", None) ): - logging.info(f"Skipping batch: {batch[0].sid.tolist()}") + logging.info(f"Skipping batch: {batch.sid.tolist()}") continue relaxed_batch = ml_relax(