Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: remove env.DEVICE in all forward functions #3330

Merged
merged 3 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,14 @@
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]

def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]:

Check warning on line 95 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L95

Added line #L95 was not covered by tests
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
rcuts = torch.tensor(
self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
)
nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device)
nsels = torch.tensor(self.get_model_nsels(), device=device)

Check warning on line 98 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L97-L98

Added lines #L97 - L98 were not covered by tests
zipped = torch.stack(
[
torch.tensor(rcuts, device=env.DEVICE),
torch.tensor(nsels, device=env.DEVICE),
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
],
dim=0,
).T
Expand Down Expand Up @@ -148,7 +146,7 @@
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device)

Check warning on line 149 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L149

Added line #L149 was not covered by tests
nlists = build_multiple_neighbor_list(
extended_coord,
nlist,
Expand Down
9 changes: 4 additions & 5 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -160,15 +157,17 @@
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.to(device=env.DEVICE).view(
self.tab_data = self.tab_data.to(device=extended_coord.device).view(

Check warning on line 160 in deepmd/pt/model/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/pairtab_atomic_model.py#L160

Added line #L160 was not covered by tests
int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4
)

# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr
# i_type : (nframes, nloc), this is atype.
# j_type : (nframes, nloc, nnei)
j_type = extended_atype[
torch.arange(extended_atype.size(0), device=env.DEVICE)[:, None, None],
torch.arange(extended_atype.size(0), device=extended_coord.device)[
:, None, None
],
masked_nlist,
]

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _update_g1_conv(
else:
gg1 = _apply_switch(gg1, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device
)
# nb x nloc x ng2
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
Expand Down Expand Up @@ -474,7 +474,7 @@ def _cal_h2g2(
else:
g2 = _apply_switch(g2, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device
)
# nb x nloc x 3 x ng2
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ def forward(
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
[nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
[nfnl, 4, self.filter_neuron[-1]],
dtype=self.prec,
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def _forward_common(
outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
device=descriptor.device,
) # jit assertion
if self.old_impl:
assert self.filter_layers_old is not None
Expand Down
21 changes: 11 additions & 10 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@
maping extended index to the local index

"""
device = coord.device

Check warning on line 291 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L291

Added line #L291 was not covered by tests
nf, nloc = atype.shape
aidx = torch.tile(torch.arange(nloc, device=env.DEVICE).unsqueeze(0), [nf, 1])
aidx = torch.tile(torch.arange(nloc, device=device).unsqueeze(0), [nf, 1])

Check warning on line 293 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L293

Added line #L293 was not covered by tests
if cell is None:
nall = nloc
extend_coord = coord.clone()
Expand All @@ -306,17 +307,17 @@
nbuff = torch.ceil(rcut / to_face).to(torch.long)
# 3
nbuff = torch.max(nbuff, dim=0, keepdim=False).values
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=env.DEVICE)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=env.DEVICE)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=env.DEVICE)
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device)

Check warning on line 312 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L310-L312

Added lines #L310 - L312 were not covered by tests
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz.view(-1, 3)
# ns x 3
Expand All @@ -333,7 +334,7 @@
extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1])

return (
extend_coord.reshape([nf, nall * 3]).to(env.DEVICE),
extend_atype.view([nf, nall]).to(env.DEVICE),
extend_aidx.view([nf, nall]).to(env.DEVICE),
extend_coord.reshape([nf, nall * 3]).to(device),
extend_atype.view([nf, nall]).to(device),
extend_aidx.view([nf, nall]).to(device),
)
8 changes: 8 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,11 @@ def setUp(self):
)
freeze(ns)
self.model = frozen_model

# Note: this can not actually disable cuda device to be used
# only can be used to test whether devices are mismatched
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.mock.patch("deepmd.pt.utils.env.DEVICE", torch.device("cpu"))
@unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu"))
def test_dp_test_cpu(self):
self.test_dp_test()
Loading