From b7c75313144dab6b783d37e72e5119cd30671371 Mon Sep 17 00:00:00 2001 From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Date: Thu, 8 Feb 2024 22:36:46 +0100 Subject: [PATCH] Remove static functions (#41) --- examples/power_spectrum.py | 5 ++++- tests/test_spherical_expansions.py | 12 ++++++------ torch_spex/radial_basis.py | 5 ++++- torch_spex/spherical_expansions.py | 12 +++++++++--- torch_spex/splines.py | 4 ++++ 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/examples/power_spectrum.py b/examples/power_spectrum.py index c7a185b..cc6f943 100644 --- a/examples/power_spectrum.py +++ b/examples/power_spectrum.py @@ -35,7 +35,10 @@ def forward(self, spex: TensorMap): values=ps_values_ai, samples=spex.block({"lam": 0, "a_i": a_i}).samples, components=[], - properties=Labels.range("property", ps_values_ai.shape[-1]) + properties=Labels( + "property", + torch.arange(ps_values_ai.shape[-1], device=ps_values_ai.device).reshape(-1, 1) + ) ) keys.append([a_i]) blocks.append(block) diff --git a/tests/test_spherical_expansions.py b/tests/test_spherical_expansions.py index 1f4cfb9..b288f41 100644 --- a/tests/test_spherical_expansions.py +++ b/tests/test_spherical_expansions.py @@ -31,7 +31,7 @@ class TestEthanol1SphericalExpansion: def test_vector_expansion_coeffs(self): tm_ref = metatensor.torch.load("tests/data/vector_expansion_coeffs-ethanol1_0-data.npz") - tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) + tm_ref = tm_ref.to(device=self.device, dtype=self.dtype) # we need to sort both computed and reference pair expansion coeffs, # because ase.neighborlist can get different neighborlist order for some reasons tm_ref = metatensor.torch.sort(tm_ref) @@ -51,7 +51,7 @@ def test_vector_expansion_coeffs(self): def test_spherical_expansion_coeffs(self): tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-data.npz") - tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) + tm_ref = tm_ref.to(device=self.device, dtype=self.dtype) spherical_expansion_calculator = SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype) with torch.no_grad(): tm = spherical_expansion_calculator.forward(**self.batch) @@ -70,7 +70,7 @@ def test_spherical_expansion_coeffs_alchemical(self): with open("tests/data/expansion_coeffs-ethanol1_0-alchemical-hypers.json", "r") as f: hypers = json.load(f) tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-alchemical-seed0-data.npz") - tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) + tm_ref = tm_ref.to(device=self.device, dtype=self.dtype) torch.manual_seed(0) spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype) # Because setting seed seems not be enough to get the same initial combination matrix @@ -111,7 +111,7 @@ class TestArtificialSphericalExpansion: def test_vector_expansion_coeffs(self): tm_ref = metatensor.torch.load("tests/data/vector_expansion_coeffs-artificial-data.npz") - tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) + tm_ref = tm_ref.to(device=self.device, dtype=self.dtype) tm_ref = metatensor.torch.sort(tm_ref) vector_expansion = VectorExpansion(self.hypers, self.all_species).to(self.device, self.dtype) with torch.no_grad(): @@ -120,7 +120,7 @@ def test_vector_expansion_coeffs(self): def test_spherical_expansion_coeffs(self): tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-data.npz") - tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) + tm_ref = tm_ref.to(device=self.device, dtype=self.dtype) spherical_expansion_calculator = SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype) with torch.no_grad(): tm = spherical_expansion_calculator.forward(**self.batch) @@ -132,7 +132,7 @@ def test_spherical_expansion_coeffs_artificial(self): with open("tests/data/expansion_coeffs-artificial-alchemical-hypers.json", "r") as f: hypers = json.load(f) tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-alchemical-seed0-data.npz") - tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) + tm_ref = tm_ref.to(device=self.device, dtype=self.dtype) spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype) with torch.no_grad(): spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_( diff --git a/torch_spex/radial_basis.py b/torch_spex/radial_basis.py index 09ca8bc..e3cebf8 100644 --- a/torch_spex/radial_basis.py +++ b/torch_spex/radial_basis.py @@ -56,7 +56,10 @@ def __init__(self, hypers, all_species) -> None: self.is_alchemical = False self.n_pseudo_species = 0 # dummy for torchscript self.combination_matrix = torch.nn.Linear(1, 1) # dummy for torchscript - self.species_neighbor_labels = Labels.empty("dummy") + self.species_neighbor_labels = Labels( + names=["dummy"], + values=torch.empty((0, 1), dtype=torch.int) + ) self.apply_mlp = False if hypers["mlp"]: diff --git a/torch_spex/spherical_expansions.py b/torch_spex/spherical_expansions.py index cafa789..cddf156 100644 --- a/torch_spex/spherical_expansions.py +++ b/torch_spex/spherical_expansions.py @@ -67,7 +67,7 @@ class SphericalExpansion(torch.nn.Module): >>> loader = DataLoader(dataset, batch_size=1, collate_fn=collate_nl) >>> batch = next(iter(loader)) >>> spherical_expansion = SphericalExpansion(hypers, [1, 8]) - >>> expansion = spherical_expansion.forward(**batch) + >>> expansion = spherical_expansion(**batch) >>> print(expansion.keys) Labels( a_i lam sigma @@ -356,7 +356,10 @@ def forward(self, ) ) else: - properties = Labels.range("n", n_max_l) + properties = Labels( + names=["n"], + values = torch.arange(n_max_l, device=vector_expansion_l.device).reshape(n_max_l, 1) + ) vector_expansion_blocks.append( TensorBlock( values = vector_expansion_l.reshape(vector_expansion_l.shape[0], 2*l+1, -1), @@ -420,7 +423,10 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1)) ) ], - properties = Labels.single().to(direction_vectors.device) + properties = Labels( + names=["_"], + values=torch.zeros((1, 1), dtype=torch.int, device=direction_vectors.device) + ) ) return block diff --git a/torch_spex/splines.py b/torch_spex/splines.py index cdc4ab0..fcc1bbb 100644 --- a/torch_spex/splines.py +++ b/torch_spex/splines.py @@ -126,6 +126,10 @@ def __init__( self.spline_values = concatenated_values[sort_indices] self.spline_derivatives = concatenated_derivatives[sort_indices] + self.spline_positions = self.spline_positions.to(torch.get_default_dtype()) + self.spline_values = self.spline_values.to(torch.get_default_dtype()) + self.spline_derivatives = self.spline_derivatives.to(torch.get_default_dtype()) + def compute(self, positions): x = positions delta_x = self.spline_positions[1] - self.spline_positions[0]