Skip to content

Commit

Permalink
ut: add null test (#3391)
Browse files Browse the repository at this point in the history
test the cases: 
1. system only has one atom
2. system has two atoms that are far away from each other. 

In each cases, the force and virial predictions should be zero and
energy should be a valid float.

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang authored Mar 3, 2024
1 parent da014e7 commit 4f933d8
Showing 1 changed file with 145 additions and 0 deletions.
145 changes: 145 additions & 0 deletions source/tests/pt/model/test_null_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import unittest

import numpy as np
import torch

from deepmd.pt.infer.deep_eval import (
eval_model,
)
from deepmd.pt.model.model import (
get_model,
get_zbl_model,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)

from .test_permutation import (
model_dpa1,
model_dpa2,
model_hybrid,
model_se_e2_a,
model_zbl,
)

dtype = torch.float64


class NullTest:
def test_nloc_1(
self,
):
natoms = 1
# torch.manual_seed(1000)
cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE)
# large box to exclude images
cell = (cell + cell.T) + 100.0 * torch.eye(3, device=env.DEVICE)
coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE)
atype = torch.tensor([0], dtype=torch.int32, device=env.DEVICE)
e0, f0, v0 = eval_model(
self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype
)
ret0 = {
"energy": e0.squeeze(0),
"force": f0.squeeze(0),
"virial": v0.squeeze(0),
}
prec = 1e-10
expect_e_shape = [1]
expect_f = torch.zeros([natoms, 3], dtype=dtype, device=env.DEVICE)
expect_v = torch.zeros([9], dtype=dtype, device=env.DEVICE)
self.assertEqual(list(ret0["energy"].shape), expect_e_shape)
self.assertFalse(np.isnan(to_numpy_array(ret0["energy"])[0]))
torch.testing.assert_close(ret0["force"], expect_f, rtol=prec, atol=prec)
if not hasattr(self, "test_virial") or self.test_virial:
torch.testing.assert_close(ret0["virial"], expect_v, rtol=prec, atol=prec)

def test_nloc_2_far(
self,
):
natoms = 2
cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE)
# large box to exclude images
cell = (cell + cell.T) + 3000.0 * torch.eye(3, device=env.DEVICE)
coord = torch.rand([1, 3], dtype=dtype, device=env.DEVICE)
# 2 far-away atoms
coord = torch.cat([coord, coord + 100.0], dim=0)
atype = torch.tensor([0, 2], dtype=torch.int32, device=env.DEVICE)
e0, f0, v0 = eval_model(
self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype
)
ret0 = {
"energy": e0.squeeze(0),
"force": f0.squeeze(0),
"virial": v0.squeeze(0),
}
prec = 1e-10
expect_e_shape = [1]
expect_f = torch.zeros([natoms, 3], dtype=dtype, device=env.DEVICE)
expect_v = torch.zeros([9], dtype=dtype, device=env.DEVICE)
self.assertEqual(list(ret0["energy"].shape), expect_e_shape)
self.assertFalse(np.isnan(to_numpy_array(ret0["energy"])[0]))
torch.testing.assert_close(ret0["force"], expect_f, rtol=prec, atol=prec)
if not hasattr(self, "test_virial") or self.test_virial:
torch.testing.assert_close(ret0["virial"], expect_v, rtol=prec, atol=prec)


class TestEnergyModelSeA(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_se_e2_a)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelDPA1(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa1)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelDPA2(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa2)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestForceModelDPA2(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa2)
model_params["fitting_net"]["type"] = "direct_force_ener"
self.type_split = True
self.test_virial = False
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
model_params["fitting_net"]["type"] = "direct_force_ener"
self.type_split = True
self.test_virial = False
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("FAILED at the moment")
class TestEnergyModelZBL(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_zbl)
self.type_split = False
self.model = get_zbl_model(model_params).to(env.DEVICE)

0 comments on commit 4f933d8

Please sign in to comment.