From cddc6b6ad0ce5c9ef1d4b3acd77547a2c30ecbc4 Mon Sep 17 00:00:00 2001 From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Date: Mon, 4 Mar 2024 06:41:50 +0100 Subject: [PATCH] Rename labels (#42) --- examples/alchemical_model.py | 2 +- examples/power_spectrum.py | 4 +- examples/ps_model.py | 2 +- tests/compare_vs_rascaline.py | 6 +-- tests/plot_radial_basis.py | 2 +- tests/test_finite_differences.py | 2 +- tests/test_spherical_expansions.py | 64 +++++++++++++++++++++++++++--- torch_spex/radial_basis.py | 6 +-- torch_spex/spherical_expansions.py | 33 ++++++++------- 9 files changed, 86 insertions(+), 35 deletions(-) diff --git a/examples/alchemical_model.py b/examples/alchemical_model.py index f881554..5ca8c45 100644 --- a/examples/alchemical_model.py +++ b/examples/alchemical_model.py @@ -168,7 +168,7 @@ def forward(self, structure_batch: Dict[str, torch.Tensor], is_training: bool = atomic_energies = [] structure_indices = [] for ai, layer_ai in self.nu2_model.items(): - block = ps.block({"a_i": int(ai)}) + block = ps.block({"center_type": int(ai)}) # print(block.values) features = block.values.squeeze(dim=1) structure_indices.append(block.samples.column("structure")) diff --git a/examples/power_spectrum.py b/examples/power_spectrum.py index 53897db..70de3ac 100644 --- a/examples/power_spectrum.py +++ b/examples/power_spectrum.py @@ -19,7 +19,7 @@ def forward(self, spex: TensorMap): ps_values_ai = [] for l in range(self.l_max+1): cg = (2*l+1)**(-0.5) - block_ai_l = spex.block({"lam": l, "a_i": a_i}) + block_ai_l = spex.block({"o3_lambda": l, "center_type": a_i}) c_ai_l = block_ai_l.values # same as this: @@ -33,7 +33,7 @@ def forward(self, spex: TensorMap): block = TensorBlock( values=ps_values_ai, - samples=spex.block({"lam": 0, "a_i": a_i}).samples, + samples=spex.block({"o3_lambda": 0, "center_type": a_i}).samples, components=[], properties=Labels( "property", diff --git a/examples/ps_model.py b/examples/ps_model.py index ad7d1b9..f27c961 100644 --- a/examples/ps_model.py +++ b/examples/ps_model.py @@ -141,7 +141,7 @@ def forward(self, structures: Dict[str, torch.Tensor], is_training: bool = True) atomic_energies = [] structure_indices = [] for ai, model_ai in self.nu2_model.items(): - block = ps.block({"a_i": int(ai)}) + block = ps.block({"center_type": int(ai)}) features = block.values.squeeze(dim=1) structure_indices.append(block.samples.column("structure")) atomic_energies.append( diff --git a/tests/compare_vs_rascaline.py b/tests/compare_vs_rascaline.py index ae85668..94f02d7 100644 --- a/tests/compare_vs_rascaline.py +++ b/tests/compare_vs_rascaline.py @@ -135,15 +135,15 @@ def function_for_splining_derivative(n, l, r): spherical_expansion_coefficients_rascaline = calculator.compute(structures) all_neighbor_species = Labels( - names=["species_neighbor"], + names=["neighbor_type"], values=np.array(all_species, dtype=np.int32).reshape(-1, 1), ) spherical_expansion_coefficients_rascaline = spherical_expansion_coefficients_rascaline.keys_to_properties(all_neighbor_species) for a_i in all_species: for l in range(l_max+1): - e = spherical_expansion_coefficients_torch_spex.block(lam=l, a_i=a_i).values - n_max_l = spherical_expansion_coefficients_torch_spex.block(lam=l, a_i=a_i).values.shape[2] // len(all_species) + e = spherical_expansion_coefficients_torch_spex.block(o3_lambda=l, center_type=a_i).values + n_max_l = spherical_expansion_coefficients_torch_spex.block(o3_lambda=l, center_type=a_i).values.shape[2] // len(all_species) rascaline_indices = [] for a_i_index in range(len(all_species)): for n in range(n_max_l): diff --git a/tests/plot_radial_basis.py b/tests/plot_radial_basis.py index 17e3267..6b68a2a 100644 --- a/tests/plot_radial_basis.py +++ b/tests/plot_radial_basis.py @@ -39,7 +39,7 @@ def get_dummy_structures(r_array): calculator = SphericalExpansion(hypers_spherical_expansion, [1, 6]) spherical_expansion_coefficients = calculator(**structures) -block_C_0 = spherical_expansion_coefficients.block(a_i = 6, lam = 0) +block_C_0 = spherical_expansion_coefficients.block(center_type = 6, o3_lambda = 0) print("Block shape is", block_C_0.values.shape) block_C_0_0 = block_C_0.values[:, :, 2].flatten().detach().numpy() diff --git a/tests/test_finite_differences.py b/tests/test_finite_differences.py index e720f9f..590b250 100644 --- a/tests/test_finite_differences.py +++ b/tests/test_finite_differences.py @@ -46,7 +46,7 @@ def forward(self, spherical_expansion_kwargs, is_compute_forces=True): spherical_expansion_kwargs["positions"].requires_grad = True if is_compute_forces: spherical_expansion = self.spherical_expansion_calculator(**spherical_expansion_kwargs) - tm = metatensor.torch.sum_over_samples(spherical_expansion, sample_names="center").components_to_properties(["m"]).keys_to_properties(["a_i", "lam", "sigma"]) + tm = metatensor.torch.sum_over_samples(spherical_expansion, sample_names="atom").components_to_properties(["o3_mu"]).keys_to_properties(["center_type", "o3_lambda", "o3_sigma"]) energies = torch.sum(tm.block().values, axis=1) gradient = torch.autograd.grad( diff --git a/tests/test_spherical_expansions.py b/tests/test_spherical_expansions.py index b288f41..4481fc5 100644 --- a/tests/test_spherical_expansions.py +++ b/tests/test_spherical_expansions.py @@ -9,6 +9,57 @@ from torch_spex.structures import InMemoryDataset, TransformerNeighborList, collate_nl from torch.utils.data import DataLoader + +def rename_old_labels(labels: metatensor.torch.Labels): + # The reference values were saved with old names + new_labels = labels + if "l" in labels.names: + new_labels = new_labels.rename("l", "o3_lambda") + if "a_i" in labels.names: + new_labels = new_labels.rename("a_i", "center_type") + if "lam" in new_labels.names: + new_labels = new_labels.rename("lam", "o3_lambda") + if "sigma" in new_labels.names: + new_labels = new_labels.rename("sigma", "o3_sigma") + if "m" in new_labels.names: + new_labels = new_labels.rename("m", "o3_mu") + if "l1" in new_labels.names: + new_labels = new_labels.remove("l1") + if "center" in new_labels.names and "neighbor" not in new_labels.names: + new_labels = new_labels.rename("center", "atom") + if "a1" in new_labels.names: + new_labels = new_labels.rename("a1", "neighbor_type") + if "alphaj" in new_labels.names: + new_labels = new_labels.rename("alphaj", "neighbor_type") + if "alpha_j" in new_labels.names: + new_labels = new_labels.rename("alpha_j", "neighbor_type") + if "n1" in new_labels.names: + new_labels = new_labels.rename("n1", "n") + if "species_center" in new_labels.names: + new_labels = new_labels.rename("species_center", "center_type") + if "species_neighbor" in new_labels.names: + new_labels = new_labels.rename("species_neighbor", "neighbor_type") + if "direction" in new_labels.names: + new_labels = new_labels.rename("direction", "xyz") + return new_labels + +def rename_old_tm(tm: metatensor.torch.TensorMap): + # The reference values were saved with old names + keys = rename_old_labels(tm.keys) + blocks = [] + for block in tm.blocks(): + blocks.append( + metatensor.torch.TensorBlock( + values=block.values, + samples=rename_old_labels(block.samples), + components=[rename_old_labels(component) for component in block.components], + properties=rename_old_labels(block.properties) + ) + ) + + return metatensor.torch.TensorMap(keys=keys, blocks=blocks) + + class TestEthanol1SphericalExpansion: """ Tests on the ethanol1 dataset @@ -41,7 +92,7 @@ def test_vector_expansion_coeffs(self): # Default types are float32 so we cannot get higher accuracy than 1e-7. # Because the reference value have been cacluated using float32 and # now we using float64 computation the accuracy had to be decreased again - assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5) vector_expansion_script = torch.jit.script(vector_expansion) with torch.no_grad(): @@ -58,7 +109,7 @@ def test_spherical_expansion_coeffs(self): # Default types are float32 so we cannot get higher accuracy than 1e-7. # Because the reference value have been cacluated using float32 and # now we using float64 computation the accuracy had to be decreased again - assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5) spherical_expansion_script = torch.jit.script(spherical_expansion_calculator) with torch.no_grad(): @@ -88,7 +139,7 @@ def test_spherical_expansion_coeffs_alchemical(self): # Default types are float32 so we cannot get higher accuracy than 1e-7. # Because the reference value have been cacluated using float32 and # now we using float64 computation the accuracy had to be decreased again - assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5) class TestArtificialSphericalExpansion: """ @@ -116,7 +167,7 @@ def test_vector_expansion_coeffs(self): vector_expansion = VectorExpansion(self.hypers, self.all_species).to(self.device, self.dtype) with torch.no_grad(): tm = metatensor.torch.sort(vector_expansion.forward(**self.batch)) - assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5) def test_spherical_expansion_coeffs(self): tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-data.npz") @@ -126,7 +177,7 @@ def test_spherical_expansion_coeffs(self): tm = spherical_expansion_calculator.forward(**self.batch) # The absolute accuracy is a bit smaller than in the ethanol case # I presume it is because we use 5 frames instead of just one - assert metatensor.torch.allclose(tm_ref, tm, atol=3e-5, rtol=1e-5) + assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=3e-5, rtol=1e-5) def test_spherical_expansion_coeffs_artificial(self): with open("tests/data/expansion_coeffs-artificial-alchemical-hypers.json", "r") as f: @@ -144,4 +195,5 @@ def test_spherical_expansion_coeffs_artificial(self): ) with torch.no_grad(): tm = spherical_expansion_calculator.forward(**self.batch) - assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + print(rename_old_tm(tm_ref).block(0).properties) + assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5) diff --git a/torch_spex/radial_basis.py b/torch_spex/radial_basis.py index af552fe..1ff95b7 100644 --- a/torch_spex/radial_basis.py +++ b/torch_spex/radial_basis.py @@ -49,7 +49,7 @@ def __init__(self, hypers, all_species) -> None: torch.nn.Linear(len(all_species), self.n_pseudo_species, bias=False) ) self.species_neighbor_labels = Labels( - names = ["species_neighbor"], + names = ["neighbor_type"], values = torch.tensor(self.all_species, dtype=torch.int).unsqueeze(1) ) else: @@ -86,8 +86,8 @@ def __init__(self, hypers, all_species) -> None: def radial_transform(self, r, samples_metadata: Labels): if self.is_physical: - a_i = samples_metadata.column("species_center") - a_j = samples_metadata.column("species_neighbor") + a_i = samples_metadata.column("center_type") + a_j = samples_metadata.column("neighbor_type") x = r/(0.1+torch.exp(self.lengthscales[a_i])+torch.exp(self.lengthscales[a_j])) return x else: diff --git a/torch_spex/spherical_expansions.py b/torch_spex/spherical_expansions.py index f2caf0d..6e723cc 100644 --- a/torch_spex/spherical_expansions.py +++ b/torch_spex/spherical_expansions.py @@ -70,9 +70,9 @@ class SphericalExpansion(torch.nn.Module): >>> expansion = spherical_expansion(**batch) >>> print(expansion.keys) Labels( - a_i lam sigma - 1 0 1 - 8 0 1 + center_type o3_lambda o3_sigma + 1 0 1 + 8 0 1 ) """ @@ -141,7 +141,7 @@ def forward(self, expanded_vectors = self.vector_expansion_calculator( positions, cells, species, cell_shifts, centers, pairs, structure_centers, structure_pairs, structure_offsets) - samples_metadata = expanded_vectors.block({"l": 0}).samples + samples_metadata = expanded_vectors.block({"o3_lambda": 0}).samples n_species = len(self.all_species) species_to_index = {atomic_number : i_species for i_species, atomic_number in enumerate(self.all_species)} @@ -156,7 +156,7 @@ def forward(self, if self.is_alchemical: density_indices = s_i_metadata_to_unique for l in range(l_max+1): - expanded_vectors_l = expanded_vectors.block({"l": l}).values + expanded_vectors_l = expanded_vectors.block({"o3_lambda": l}).values densities_l = torch.zeros( (n_centers, expanded_vectors_l.shape[1], expanded_vectors_l.shape[2]), dtype = expanded_vectors_l.dtype, @@ -167,12 +167,12 @@ def forward(self, densities.append(densities_l) unique_species = -torch.arange(self.n_pseudo_species, dtype=torch.int64, device=density_indices.device) else: - aj_metadata = samples_metadata.column("species_neighbor") + aj_metadata = samples_metadata.column("neighbor_type") aj_shifts = torch.tensor([species_to_index[int(aj_index)] for aj_index in aj_metadata], dtype=torch.int64, device=aj_metadata.device) density_indices = s_i_metadata_to_unique*n_species+aj_shifts for l in range(l_max+1): - expanded_vectors_l = expanded_vectors.block({"l": l}).values + expanded_vectors_l = expanded_vectors.block({"o3_lambda": l}).values densities_l = torch.zeros( (n_centers*n_species, expanded_vectors_l.shape[1], expanded_vectors_l.shape[2]), dtype = expanded_vectors_l.dtype, @@ -188,7 +188,7 @@ def forward(self, blocks : List[TensorBlock] = [] for l in range(l_max+1): densities_l = densities[l] - vectors_l_block = expanded_vectors.block({"l": l}) + vectors_l_block = expanded_vectors.block({"o3_lambda": l}) vectors_l_block_components = vectors_l_block.components vectors_l_block_n = torch.arange(len(torch.unique(vectors_l_block.properties.column("n"))), dtype=torch.int64, device=species.device) # Need to be smarter to optimize for a_i in self.all_species: @@ -205,17 +205,16 @@ def forward(self, TensorBlock( values = densities_ai_l, samples = Labels( - names = ["structure", "center"], + names = ["structure", "atom"], values = unique_s_i_indices[where_ai] ), components = vectors_l_block_components, properties = Labels( - names = ["a1", "n1", "l1"], + names = ["neighbor_type", "n"], values = torch.stack( [ torch.repeat_interleave(unique_species, vectors_l_block_n.shape[0]), torch.tile(vectors_l_block_n, (unique_species.shape[0],)), - l*torch.ones((densities_ai_l.shape[2],), dtype=torch.int, device=densities_ai_l.device) ], dim=1 ) @@ -225,7 +224,7 @@ def forward(self, spherical_expansion = TensorMap( keys = Labels( - names = ["a_i", "lam", "sigma"], + names = ["center_type", "o3_lambda", "o3_sigma"], values = torch.tensor(labels, dtype=torch.int32, device=species.device) ), blocks = blocks @@ -353,7 +352,7 @@ def forward(self, n_max_l = vector_expansion_l.shape[2] if self.is_alchemical: properties = Labels( - names = ["alpha_j", "n"], + names = ["neighbor_type", "n"], values = torch.stack( [ torch.repeat_interleave(-torch.arange(self.n_pseudo_species, dtype=torch.int64, device=vector_expansion_l.device), n_max_l), @@ -372,7 +371,7 @@ def forward(self, values = vector_expansion_l.reshape(vector_expansion_l.shape[0], 2*l+1, -1), samples = cartesian_vectors.samples, components = [Labels( - names = ("m",), + names = ("o3_mu",), values = torch.arange(start=-l, end=l+1, dtype=torch.int32, device=vector_expansion_l.device).reshape(2*l+1, 1) )], properties = properties.to(vector_expansion_l.device) @@ -382,7 +381,7 @@ def forward(self, l_max = len(vector_expansion_blocks) - 1 vector_expansion_tmap = TensorMap( keys = Labels( - names = ("l",), + names = ("o3_lambda",), values = torch.arange(start=0, end=l_max+1, dtype=torch.int32, device=vector_expansion_blocks[0].values.device).reshape(l_max+1, 1), ), blocks = vector_expansion_blocks @@ -421,12 +420,12 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs block = TensorBlock( values = direction_vectors.unsqueeze(dim=-1), samples = Labels( - names = ["structure", "center", "neighbor", "species_center", "species_neighbor", "cell_x", "cell_y", "cell_z"], + names = ["structure", "center", "neighbor", "center_type", "neighbor_type", "cell_x", "cell_y", "cell_z"], values = labels ), components = [ Labels( - names = ["cartesian_dimension"], + names = ["xyz"], values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1)) ) ],