Skip to content

Commit

Permalink
pt: improve nlist performance (#3425)
Browse files Browse the repository at this point in the history
1. use inv_ex instead of inv. `inv_ex` does not check errors. We can
assume the input is correct.
2. pass CPU box for `torch.arange`;
3. avoid torch.tensor.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Mar 7, 2024
1 parent fa8e645 commit 2d48d1f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def get_data(self, is_train=True, task_key="Default"):
batch_data = next(iter(self.validation_data[task_key]))

for key in batch_data.keys():
if key == "sid" or key == "fid":
if key == "sid" or key == "fid" or key == "box":
continue
elif not isinstance(batch_data[key], list):
if batch_data[key] is not None:
Expand Down
33 changes: 18 additions & 15 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ def extend_input_and_build_neighbor_list(
):
nframes, nloc = atype.shape[:2]
if box is not None:
box_gpu = box.to(coord.device, non_blocking=True)
coord_normalized = normalize_coord(
coord.view(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
box_gpu.reshape(nframes, 3, 3),
)
else:
box_gpu = None
coord_normalized = coord.clone()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, box, rcut
coord_normalized, atype, box_gpu, rcut, box
)
nlist = build_neighbor_list(
extended_coord,
Expand Down Expand Up @@ -262,6 +264,7 @@ def extend_coord_with_ghosts(
atype: torch.Tensor,
cell: Optional[torch.Tensor],
rcut: float,
cell_cpu: Optional[torch.Tensor] = None,
):
"""Extend the coordinates of the atoms by appending peridoc images.
The number of images is large enough to ensure all the neighbors
Expand All @@ -277,6 +280,8 @@ def extend_coord_with_ghosts(
simulation cell tensor of shape [-1, 9].
rcut : float
the cutoff radius
cell_cpu : torch.Tensor
cell on cpu for performance
Returns
-------
Expand All @@ -299,27 +304,25 @@ def extend_coord_with_ghosts(
else:
coord = coord.view([nf, nloc, 3])
cell = cell.view([nf, 3, 3])
cell_cpu = cell_cpu.view([nf, 3, 3]) if cell_cpu is not None else cell
# nf x 3
to_face = to_face_distance(cell)
to_face = to_face_distance(cell_cpu)
# nf x 3
# *2: ghost copies on + and - directions
# +1: central cell
nbuff = torch.ceil(rcut / to_face).to(torch.long)
# 3
nbuff = torch.max(nbuff, dim=0, keepdim=False).values
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
nbuff_cpu = nbuff.cpu()
xi = torch.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu")
yi = torch.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1, device="cpu")
zi = torch.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1, device="cpu")
eye_3 = torch.eye(3, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device="cpu")
xyz = xi.view(-1, 1, 1, 1) * eye_3[0]
xyz = xyz + yi.view(1, -1, 1, 1) * eye_3[1]
xyz = xyz + zi.view(1, 1, -1, 1) * eye_3[2]
xyz = xyz.view(-1, 3)
xyz = xyz.to(device=device, non_blocking=True)
# ns x 3
shift_idx = xyz[torch.argsort(torch.norm(xyz, dim=1))]
ns, _ = shift_idx.shape
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def phys2inter(
the internal coordinates
"""
rec_cell = torch.linalg.inv(cell)
rec_cell, _ = torch.linalg.inv_ex(cell)
return torch.matmul(coord, rec_cell)


Expand Down

0 comments on commit 2d48d1f

Please sign in to comment.