Skip to content

Commit

Permalink
add get_type_map method to model; export model methods (#3247)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Feb 12, 2024
1 parent c131c8f commit beb1b98
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 5 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""
return self.type_map

def distinguish_types(self) -> bool:
"""Returns if model requires a neighbor list that distinguish different
atomic types or not.
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""
raise NotImplementedError("TODO: get_type_map should be implemented")

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def get_rcut(self) -> float:
"""Get the cut-off radius."""
pass

@abstractmethod
def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def fitting_output_def(self) -> FittingOutputDef:
def get_rcut(self) -> float:
return self.rcut

def get_type_map(self) -> Optional[List[str]]:
raise NotImplementedError("TODO: get_type_map should be implemented")

def get_sel(self) -> List[int]:
return [self.sel]

Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ def fitting_output_def(self) -> FittingOutputDef:
else self.coord_denoise_net.output_def()
)

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.rcut

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.sel
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ def distinguish_types(self) -> bool:
"""If distinguish different types by sorting."""
return False

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
raise NotImplementedError("TODO: implement this method")

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,14 @@ def fitting_output_def(self) -> FittingOutputDef:
]
)

@torch.jit.export
def get_rcut(self) -> float:
return self.rcut

@torch.jit.export
def get_type_map(self) -> Optional[List[str]]:
raise NotImplementedError("TODO: implement this method")

def get_sel(self) -> List[int]:
return [self.sel]

Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,6 @@ def test_jit(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE)
torch.jit.script(md0)
md0 = torch.jit.script(md0)
self.assertEqual(md0.get_rcut(), self.rcut)
self.assertEqual(md0.get_type_map(), type_map)
8 changes: 6 additions & 2 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def test_jit(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE)
torch.jit.script(md0)
md0 = torch.jit.script(md0)
md0.get_rcut()
md0.get_type_map()


class TestDPModelFormatNlist(unittest.TestCase):
Expand Down Expand Up @@ -521,4 +523,6 @@ def test_jit(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)
torch.jit.script(md0)
md0 = torch.jit.script(md0)
self.assertEqual(md0.get_rcut(), self.rcut)
self.assertEqual(md0.get_type_map(), type_map)
10 changes: 8 additions & 2 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ def test_self_consistency(self):
)

def test_jit(self):
torch.jit.script(self.md1)
torch.jit.script(self.md3)
md1 = torch.jit.script(self.md1)
self.assertEqual(md1.get_rcut(), self.rcut)
with self.assertRaises(torch.jit.Error):
self.assertEqual(md1.get_type_map(), ["foo", "bar"])
md3 = torch.jit.script(self.md3)
self.assertEqual(md3.get_rcut(), self.rcut)
with self.assertRaises(torch.jit.Error):
self.assertEqual(md3.get_type_map(), ["foo", "bar"])


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def test_with_mask(self):

def test_jit(self):
model = torch.jit.script(self.model)
self.assertEqual(model.get_rcut(), 0.02)
with self.assertRaises(torch.jit.Error):
self.assertEqual(model.get_type_map(), None)

def test_deserialize(self):
model1 = PairTabModel.deserialize(self.model.serialize())
Expand All @@ -101,6 +104,9 @@ def test_deserialize(self):
)

model1 = torch.jit.script(model1)
self.assertEqual(model1.get_rcut(), 0.02)
with self.assertRaises(torch.jit.Error):
self.assertEqual(model1.get_type_map(), None)

def test_cross_deserialize(self):
model_dict = self.model.serialize() # pytorch model to dict
Expand Down

0 comments on commit beb1b98

Please sign in to comment.