diff --git a/deepmd/dpmodel/model/dp_atomic_model.py b/deepmd/dpmodel/model/dp_atomic_model.py index 63c44aa1f8..4bb6cb1daf 100644 --- a/deepmd/dpmodel/model/dp_atomic_model.py +++ b/deepmd/dpmodel/model/dp_atomic_model.py @@ -62,6 +62,10 @@ def get_sel(self) -> List[int]: """Get the neighbor selection.""" return self.descriptor.get_sel() + def get_type_map(self) -> Optional[List[str]]: + """Get the type map.""" + return self.type_map + def distinguish_types(self) -> bool: """Returns if model requires a neighbor list that distinguish different atomic types or not. diff --git a/deepmd/dpmodel/model/linear_atomic_model.py b/deepmd/dpmodel/model/linear_atomic_model.py index dc7e9996c8..0da40307a6 100644 --- a/deepmd/dpmodel/model/linear_atomic_model.py +++ b/deepmd/dpmodel/model/linear_atomic_model.py @@ -62,6 +62,10 @@ def get_rcut(self) -> float: """Get the cut-off radius.""" return max(self.get_model_rcuts()) + def get_type_map(self) -> Optional[List[str]]: + """Get the type map.""" + raise NotImplementedError("TODO: get_type_map should be implemented") + def get_model_rcuts(self) -> List[float]: """Get the cut-off radius for each individual models.""" return [model.get_rcut() for model in self.models] diff --git a/deepmd/dpmodel/model/make_base_atomic_model.py b/deepmd/dpmodel/model/make_base_atomic_model.py index 84e685b973..080d9982c9 100644 --- a/deepmd/dpmodel/model/make_base_atomic_model.py +++ b/deepmd/dpmodel/model/make_base_atomic_model.py @@ -44,6 +44,10 @@ def get_rcut(self) -> float: """Get the cut-off radius.""" pass + @abstractmethod + def get_type_map(self) -> Optional[List[str]]: + """Get the type map.""" + @abstractmethod def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" diff --git a/deepmd/dpmodel/model/pairtab_atomic_model.py b/deepmd/dpmodel/model/pairtab_atomic_model.py index d4feb970fb..12073c7f63 100644 --- a/deepmd/dpmodel/model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/model/pairtab_atomic_model.py @@ -82,6 +82,9 @@ def fitting_output_def(self) -> FittingOutputDef: def get_rcut(self) -> float: return self.rcut + def get_type_map(self) -> Optional[List[str]]: + raise NotImplementedError("TODO: get_type_map should be implemented") + def get_sel(self) -> List[int]: return [self.sel] diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index 273e79b86b..89b814edaa 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -67,10 +67,16 @@ def fitting_output_def(self) -> FittingOutputDef: else self.coord_denoise_net.output_def() ) + @torch.jit.export def get_rcut(self) -> float: """Get the cut-off radius.""" return self.rcut + @torch.jit.export + def get_type_map(self) -> List[str]: + """Get the type map.""" + return self.type_map + def get_sel(self) -> List[int]: """Get the neighbor selection.""" return self.sel diff --git a/deepmd/pt/model/model/linear_atomic_model.py b/deepmd/pt/model/model/linear_atomic_model.py index 8b50f5e4f5..0d54e4c091 100644 --- a/deepmd/pt/model/model/linear_atomic_model.py +++ b/deepmd/pt/model/model/linear_atomic_model.py @@ -61,10 +61,16 @@ def distinguish_types(self) -> bool: """If distinguish different types by sorting.""" return False + @torch.jit.export def get_rcut(self) -> float: """Get the cut-off radius.""" return max(self.get_model_rcuts()) + @torch.jit.export + def get_type_map(self) -> List[str]: + """Get the type map.""" + raise NotImplementedError("TODO: implement this method") + def get_model_rcuts(self) -> List[float]: """Get the cut-off radius for each individual models.""" return [model.get_rcut() for model in self.models] diff --git a/deepmd/pt/model/model/pairtab_atomic_model.py b/deepmd/pt/model/model/pairtab_atomic_model.py index 2837aaffe7..0ef1448398 100644 --- a/deepmd/pt/model/model/pairtab_atomic_model.py +++ b/deepmd/pt/model/model/pairtab_atomic_model.py @@ -95,9 +95,14 @@ def fitting_output_def(self) -> FittingOutputDef: ] ) + @torch.jit.export def get_rcut(self) -> float: return self.rcut + @torch.jit.export + def get_type_map(self) -> Optional[List[str]]: + raise NotImplementedError("TODO: implement this method") + def get_sel(self) -> List[int]: return [self.sel] diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index ef25e574d4..fb7e684eaa 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -107,4 +107,6 @@ def test_jit(self): ).to(env.DEVICE) type_map = ["foo", "bar"] md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) - torch.jit.script(md0) + md0 = torch.jit.script(md0) + self.assertEqual(md0.get_rcut(), self.rcut) + self.assertEqual(md0.get_type_map(), type_map) diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 6e009d3934..f3f899fbe2 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -284,7 +284,9 @@ def test_jit(self): ).to(env.DEVICE) type_map = ["foo", "bar"] md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) - torch.jit.script(md0) + md0 = torch.jit.script(md0) + md0.get_rcut() + md0.get_type_map() class TestDPModelFormatNlist(unittest.TestCase): @@ -521,4 +523,6 @@ def test_jit(self): ).to(env.DEVICE) type_map = ["foo", "bar"] md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) - torch.jit.script(md0) + md0 = torch.jit.script(md0) + self.assertEqual(md0.get_rcut(), self.rcut) + self.assertEqual(md0.get_type_map(), type_map) diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index e0247f911f..14fc6b386a 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -171,8 +171,14 @@ def test_self_consistency(self): ) def test_jit(self): - torch.jit.script(self.md1) - torch.jit.script(self.md3) + md1 = torch.jit.script(self.md1) + self.assertEqual(md1.get_rcut(), self.rcut) + with self.assertRaises(torch.jit.Error): + self.assertEqual(md1.get_type_map(), ["foo", "bar"]) + md3 = torch.jit.script(self.md3) + self.assertEqual(md3.get_rcut(), self.rcut) + with self.assertRaises(torch.jit.Error): + self.assertEqual(md3.get_type_map(), ["foo", "bar"]) if __name__ == "__main__": diff --git a/source/tests/pt/model/test_pairtab_atomic_model.py b/source/tests/pt/model/test_pairtab_atomic_model.py index 23718c134a..f58ac76211 100644 --- a/source/tests/pt/model/test_pairtab_atomic_model.py +++ b/source/tests/pt/model/test_pairtab_atomic_model.py @@ -82,6 +82,9 @@ def test_with_mask(self): def test_jit(self): model = torch.jit.script(self.model) + self.assertEqual(model.get_rcut(), 0.02) + with self.assertRaises(torch.jit.Error): + self.assertEqual(model.get_type_map(), None) def test_deserialize(self): model1 = PairTabModel.deserialize(self.model.serialize()) @@ -101,6 +104,9 @@ def test_deserialize(self): ) model1 = torch.jit.script(model1) + self.assertEqual(model1.get_rcut(), 0.02) + with self.assertRaises(torch.jit.Error): + self.assertEqual(model1.get_type_map(), None) def test_cross_deserialize(self): model_dict = self.model.serialize() # pytorch model to dict