Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix errors for zero atom inputs #4005

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = np.prod(out_shape[2:])

Check warning on line 203 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L203

Added line #L203 was not covered by tests
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
Expand Down
20 changes: 12 additions & 8 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,9 @@
)
# nf x nloc x (ng x ng1 + tebd_dim)
if self.concat_output_tebd:
grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1)
grrg = np.concatenate(

Check warning on line 490 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L490

Added line #L490 was not covered by tests
[grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1
)
return grrg, rot_mat, None, None, sw

def serialize(self) -> dict:
Expand Down Expand Up @@ -834,7 +836,8 @@
embedding_idx,
):
nfnl, nnei = ss.shape[0:2]
ss = ss.reshape(nfnl, nnei, -1)
shape2 = np.prod(ss.shape[2:])
ss = ss.reshape(nfnl, nnei, shape2)

Check warning on line 840 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L839-L840

Added lines #L839 - L840 were not covered by tests
# nfnl x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
return gg
Expand All @@ -846,7 +849,8 @@
):
assert self.embeddings_strip is not None
nfnl, nnei = ss.shape[0:2]
ss = ss.reshape(nfnl, nnei, -1)
shape2 = np.prod(ss.shape[2:])
ss = ss.reshape(nfnl, nnei, shape2)

Check warning on line 853 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L852-L853

Added lines #L852 - L853 were not covered by tests
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
return gg
Expand Down Expand Up @@ -875,7 +879,7 @@
# nfnl x nnei x 1
sw = sw.reshape(nf * nloc, nnei, 1)
# nfnl x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1)
atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, self.tebd_dim)

Check warning on line 882 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L882

Added line #L882 was not covered by tests
# nfnl x nnei x tebd_dim
atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1))
# nfnl x nnei
Expand Down Expand Up @@ -941,10 +945,10 @@
GLOBAL_NP_FLOAT_PRECISION
)
return (
grrg.reshape(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.reshape(-1, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.reshape(-1, nloc, self.nnei, 4)[..., 1:],
gr[..., 1:].reshape(-1, nloc, self.filter_neuron[-1], 3),
grrg.reshape(nf, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.reshape(nf, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.reshape(nf, nloc, self.nnei, 4)[..., 1:],
gr[..., 1:].reshape(nf, nloc, self.filter_neuron[-1], 3),
sw,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@
h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon)
# (nf x nloc) x ng2 x 3
rot_mat = np.transpose(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw
return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw

Check warning on line 398 in deepmd/dpmodel/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/repformers.py#L398

Added line #L398 was not covered by tests

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@

res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = res.reshape(nf, nloc, -1).astype(GLOBAL_NP_FLOAT_PRECISION)
res = res.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)

Check warning on line 344 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L344

Added line #L344 was not covered by tests
return res, None, None, None, ww

def serialize(self) -> dict:
Expand Down
11 changes: 7 additions & 4 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,15 @@
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = np.max(coord) + 2.0 * rcut
if coord.size > 0:
xmax = np.max(coord) + 2.0 * rcut

Check warning on line 99 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L98-L99

Added lines #L98 - L99 were not covered by tests
else:
xmax = 2.0 * rcut

Check warning on line 101 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L101

Added line #L101 was not covered by tests
# nf x nall
is_vir = atype < 0
coord1 = np.where(is_vir[:, :, None], xmax, coord.reshape(-1, nall, 3)).reshape(
-1, nall * 3
)
coord1 = np.where(

Check warning on line 104 in deepmd/dpmodel/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/nlist.py#L104

Added line #L104 was not covered by tests
is_vir[:, :, None], xmax, coord.reshape(batch_size, nall, 3)
).reshape(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,11 @@

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = 1
for ss in out_shape[2:]:
out_shape2 *= ss

Check warning on line 261 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L259-L261

Added lines #L259 - L261 were not covered by tests
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).view(out_shape)
ret_dict["mask"] = atom_mask
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw
return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw

Check warning on line 508 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L508

Added line #L508 was not covered by tests

def compute_input_stats(
self,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@
- `torch.Tensor`: descriptor matrix with shape [nframes, natoms[0]*self.filter_neuron[-1]*self.axis_neuron].
"""
del extended_atype_embd, mapping
nf = nlist.shape[0]

Check warning on line 620 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L620

Added line #L620 was not covered by tests
nloc = nlist.shape[1]
atype = extended_atype[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
Expand Down Expand Up @@ -657,7 +658,7 @@
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)

Check warning on line 661 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L661

Added line #L661 was not covered by tests
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
Expand Down Expand Up @@ -698,8 +699,8 @@
result = torch.matmul(
xyz_scatter_1, xyz_scatter_2
) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron]
result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005

Check warning on line 703 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L702-L703

Added lines #L702 - L703 were not covered by tests
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,10 @@ def forward(
xyz_scatter_1, xyz_scatter_2
) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron]
return (
result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.view(-1, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:],
rot_mat.view(-1, nloc, self.filter_neuron[-1], 3),
result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:],
rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3),
sw,
)

Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@

"""
del mapping
nf = nlist.shape[0]

Check warning on line 346 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L346

Added line #L346 was not covered by tests
njzjz marked this conversation as resolved.
Show resolved Hide resolved
nloc = nlist.shape[1]
atype = atype_ext[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
Expand All @@ -367,7 +368,7 @@
)

# nfnl x nnei
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1)
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, self.nnei)

Check warning on line 371 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L371

Added line #L371 was not covered by tests
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
Expand All @@ -381,7 +382,7 @@

res_rescale = 1.0 / 5.0
result = xyz_scatter * res_rescale
result = result.view(-1, nloc, self.filter_neuron[-1])
result = result.view(nf, nloc, self.filter_neuron[-1])

Check warning on line 385 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L385

Added line #L385 was not covered by tests
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@

"""
del extended_atype_embd, mapping
nf = nlist.shape[0]

Check warning on line 668 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L668

Added line #L668 was not covered by tests
nloc = nlist.shape[1]
atype = extended_atype[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
Expand All @@ -687,7 +688,7 @@
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)

Check warning on line 691 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L691

Added line #L691 was not covered by tests
for embedding_idx, ll in enumerate(self.filter_layers.networks):
ti = embedding_idx % self.ntypes
nei_type_j = self.sel[ti]
Expand All @@ -714,7 +715,7 @@
res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j))
result += res_ij
# xyz_scatter /= (self.nnei * self.nnei)
result = result.view(-1, nloc, self.filter_neuron[-1])
result = result.view(nf, nloc, self.filter_neuron[-1])

Check warning on line 718 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L718

Added line #L718 was not covered by tests
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
Expand Down
12 changes: 10 additions & 2 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@
out_real, out_mag = torch.split(out_tensor, [nloc, nloc], dim=1)
if add_mag:
out_real = out_real + out_mag
out_mag = (out_mag.view([nframes, nloc, -1]) * atomic_mask).view(out_mag.shape)
shape2 = 1
for ss in out_real.shape[2:]:
shape2 *= ss
out_mag = (out_mag.view([nframes, nloc, shape2]) * atomic_mask).view(

Check warning on line 127 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L124-L127

Added lines #L124 - L127 were not covered by tests
out_mag.shape
)
return out_real, out_mag, atomic_mask > 0.0

def process_spin_output_lower(
Expand Down Expand Up @@ -162,8 +167,11 @@
)
if add_mag:
extended_out_real = extended_out_real + extended_out_mag
shape2 = 1
for ss in extended_out_tensor.shape[2:]:
shape2 *= ss

Check warning on line 172 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L170-L172

Added lines #L170 - L172 were not covered by tests
extended_out_mag = (
extended_out_mag.view([nframes, nall, -1]) * atomic_mask
extended_out_mag.view([nframes, nall, shape2]) * atomic_mask
).view(extended_out_mag.shape)
return extended_out_real, extended_out_mag, atomic_mask > 0.0

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, -1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3)

Check warning on line 194 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L194

Added line #L194 was not covered by tests
# (nframes, nloc, 3)
out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3)
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@
self.var_name
]
out = out * (self.scale.to(atype.device))[atype]
gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3)

Check warning on line 245 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L245

Added line #L245 was not covered by tests

if self.fit_diag:
out = out.reshape(-1, self.embedding_width)
Expand Down
11 changes: 7 additions & 4 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,15 @@
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = torch.max(coord) + 2.0 * rcut
if coord.numel() > 0:
xmax = torch.max(coord) + 2.0 * rcut

Check warning on line 103 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L102-L103

Added lines #L102 - L103 were not covered by tests
else:
xmax = torch.zeros(1, dtype=coord.dtype, device=coord.device) + 2.0 * rcut

Check warning on line 105 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L105

Added line #L105 was not covered by tests
# nf x nall
is_vir = atype < 0
coord1 = torch.where(is_vir[:, :, None], xmax, coord.view(-1, nall, 3)).view(
-1, nall * 3
)
coord1 = torch.where(

Check warning on line 108 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L108

Added line #L108 was not covered by tests
is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
).view(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down
103 changes: 103 additions & 0 deletions source/tests/universal/common/cases/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,109 @@ def test_forward(self):
continue
np.testing.assert_allclose(rr1, rr2, atol=aprec)

def test_zero_forward(self):
test_spin = getattr(self, "test_spin", False)
nf = 1
natoms = 0
aprec = (
0
if self.aprec_dict.get("test_forward", None) is None
else self.aprec_dict["test_forward"]
)
rng = np.random.default_rng(GLOBAL_SEED)
coord = np.zeros((nf, 0, 3), dtype=np.float64)
atype = np.zeros([nf, 0], dtype=int)
spin = np.zeros([nf, 0], dtype=np.float64)
cell = 6.0 * np.eye(3, dtype=np.float64).reshape([nf, 9])
coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
coord,
atype,
self.expected_rcut + 1.0 if test_spin else self.expected_rcut,
self.expected_sel,
mixed_types=self.module.mixed_types(),
box=cell,
)
spin_ext = np.take_along_axis(
spin.reshape(nf, -1, 3),
np.repeat(np.expand_dims(mapping, axis=-1), 3, axis=-1),
axis=1,
)
aparam = None
fparam = None
if self.module.get_dim_aparam() > 0:
aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
if self.module.get_dim_fparam() > 0:
fparam = rng.random([nf, self.module.get_dim_fparam()])
ret = []
ret_lower = []
for module in self.modules_to_test:
module = self.forward_wrapper(module)
input_dict = {
"coord": coord,
"atype": atype,
"box": cell,
"aparam": aparam,
"fparam": fparam,
}
if test_spin:
input_dict["spin"] = spin
ret.append(module(**input_dict))

input_dict_lower = {
"extended_coord": coord_ext,
"extended_atype": atype_ext,
"nlist": nlist,
"mapping": mapping,
"aparam": aparam,
"fparam": fparam,
}
if test_spin:
input_dict_lower["extended_spin"] = spin_ext

ret_lower.append(module.forward_lower(**input_dict_lower))
for kk in ret[0]:
subret = []
for rr in ret:
if rr is not None:
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
for kk in ret_lower[0]:
subret = []
for rr in ret_lower:
if rr is not None:
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
same_keys = set(ret[0].keys()) & set(ret_lower[0].keys())
self.assertTrue(same_keys)
for key in same_keys:
for rr in ret:
if rr[key] is not None:
rr1 = rr[key]
break
else:
continue
for rr in ret_lower:
if rr[key] is not None:
rr2 = rr[key]
break
else:
continue
np.testing.assert_allclose(rr1, rr2, atol=aprec)

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
def test_permutation(self):
"""Test permutation."""
Expand Down