From c2ce382fc4402043cf728518c847dfb83fbce97d Mon Sep 17 00:00:00 2001 From: Matthew Evans Date: Fri, 29 Mar 2024 11:34:35 +0000 Subject: [PATCH] Refactor tests to allow new featurizer columns to exist as long as old ones are present --- modnet/tests/test_preprocessing.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/modnet/tests/test_preprocessing.py b/modnet/tests/test_preprocessing.py index e27fa47..51c294c 100644 --- a/modnet/tests/test_preprocessing.py +++ b/modnet/tests/test_preprocessing.py @@ -12,8 +12,14 @@ def check_column_values(new: MODData, reference: MODData, tolerance=0.03): Allows for some columns to be checked more loosely (see inline comment below). """ + new_cols = set(new.df_featurized.columns) + old_cols = set(reference.df_featurized.columns) + + # Check that the new df only adds new columns and is not missing anything + assert not (old_cols - new_cols) + error_cols = set() - for col in new.df_featurized.columns: + for col in old_cols: if not ( np.absolute( ( @@ -349,14 +355,6 @@ def test_small_moddata_featurization(small_moddata_2023, featurizer_mode): featurizer.featurizer_mode = featurizer_mode new = MODData(structures, targets, target_names=names, featurizer=featurizer) new.featurize(fast=False, n_jobs=1) - - new_cols = sorted(new.df_featurized.columns.tolist()) - old_cols = sorted(old.df_featurized.columns.tolist()) - - for i in range(len(old_cols)): - assert new_cols[i] == old_cols[i] - - np.testing.assert_array_equal(old_cols, new_cols) check_column_values(new, old, tolerance=0.03) @@ -376,13 +374,6 @@ def test_small_moddata_composition_featurization( new = MODData(materials=compositions, featurizer=featurizer) new.featurize(fast=False, n_jobs=1) - new_cols = sorted(new.df_featurized.columns.tolist()) - ref_cols = sorted(reference.df_featurized.columns.tolist()) - - for i in range(len(ref_cols)): - # print(new_cols[i], ref_cols[i]) - assert new_cols[i] == ref_cols[i] - # assert relative error below 3 percent check_column_values(new, reference, tolerance=0.03)