Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Mar 28, 2024
1 parent 9729b34 commit 556a15c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/alchemical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, hypers, all_species, do_forces) -> None:
super().__init__()
self.all_species = all_species
self.spherical_expansion_calculator = SphericalExpansion(hypers, all_species)
n_max = self.spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.n_max_l
n_max = self.spherical_expansion_calculator.radial_basis_calculator.n_max_l
print("Radial basis:", n_max)
l_max = len(n_max) - 1
n_feat = sum([n_max[l]**2 * n_pseudo**2 for l in range(l_max+1)])
Expand Down
2 changes: 1 addition & 1 deletion tests/data/computing_ref_coeffs-artificial.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
spherical_expansion_calculator = SphericalExpansion(hypers, all_species)
# some random combination matrix, it is only important that we use the same one in the tests
with torch.no_grad():
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
spherical_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
torch.tensor(
[[-0.00432252, 0.30971584, -0.47518533],
[-0.4248946 , -0.22236897, 0.15482073]],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_spherical_expansion_coeffs_alchemical(self):
with torch.no_grad():
# wtf? suggested way by torch developers
# https://discuss.pytorch.org/t/initialize-nn-linear-with-specific-weights/29005/4
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(torch.tensor(
spherical_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(torch.tensor(
[[-0.00432252, 0.30971584, -0.47518533],
[-0.4248946 , -0.22236897, 0.15482073]],
device=self.device, dtype=self.dtype))
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_spherical_expansion_coeffs_artificial(self):
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
spherical_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
torch.tensor(
[[-0.00432252, 0.30971584, -0.47518533],
[-0.4248946 , -0.22236897, 0.15482073]],
Expand Down

0 comments on commit 556a15c

Please sign in to comment.