From 0dac9fe4e736b19f2b68e939c0e0fb83a8688fe7 Mon Sep 17 00:00:00 2001 From: qnater Date: Wed, 25 Sep 2024 17:08:13 +0200 Subject: [PATCH] correction of the shap explainer and refactoring --- .idea/workspace.xml | 56 +----------------- .../__pycache__/explainer.cpython-312.pyc | Bin 28190 -> 28184 bytes imputegap/explainer/explainer.py | 12 ++-- .../test_contamination_mcar.cpython-312.pyc | Bin 7947 -> 7944 bytes .../test_explainer.cpython-312.pyc | Bin 4690 -> 4644 bytes tests/test_explainer.py | 21 +++---- 6 files changed, 18 insertions(+), 71 deletions(-) diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 5226c47..6e3a926 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -2,59 +2,9 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + @@ -266,6 +216,6 @@ - + \ No newline at end of file diff --git a/imputegap/explainer/__pycache__/explainer.cpython-312.pyc b/imputegap/explainer/__pycache__/explainer.cpython-312.pyc index 1e48534ea16b80aba643bbd9e14da91cee96f178..d47fba1b9e261af69f6ae5febc0c1fa49b1ce752 100644 GIT binary patch delta 728 zcmYk1Ur19?9LMju`-7XibSErLnTD=ES_Fz@QDlajTb7W4Qd)M~?%l55dpC}E%G9Rz zhkOVeoR`%@?4{a6#N}>5qI2k_-^1s7e&65W_dVxj1*9v0 zTje-AL&|z$PwH6VKD)pHT6K8{w_Py{^pg>qIQ&Z(xV6gw+Z-r5t}fk3nNy*3f3`Mv z^27Ct!P@@AhN+8!?8oam5u$%;n z{6NZ?>P>qx)!E@(B+sGqO>J5TixE5`V|g4;OyTJnJU)wO=e8Lp3A`+k2f#ab9Ihe8W25T;!WGy;)DroSVpZ75sQ8h+9Hyytrebf%_(tO#<6?hHt=C_50PN zGm1@99uPwQ-VLyTn!Q7W3GXsUBAagoJVBp)=lDhP3HXrW5yfL<>iZ3JMye#S8 zF&)PB2z*7%vNa|^L7gmf!KkJSs;mjR9Mxc%rhcQSP<*G@Bq-YzS(WvFmtIrDCdrC8 zrU|MD-_YLmf5w<9=ui$CZ^H{_can^Ktq!8kyW3{i2IRxW z&)@-V91qol;iNs8lGQ{GM&gNt%Wxo;wQ5LA3>O7LBr!2MV>~#A-+$hJ=FNLElPkf~ z5(tZe;Ns}$s^5}ZW`#FZVB^KA0OfjB%56Ib;-;h|kt$)S`rGdAi<=xcCt=AGojRQ@ z%O-Nsytfc~-G?==Kctciy$dmNsgGP%$)Ha93}PfnYJ?<5$>!r6? z4*huA*WpuWaoa7cniNZ>%~aY%G3uZz3~}qbF9>@TC`5~A zCBGB^37=?Rfg5;F$9c-(jt7vzj_@qp#jnDx;!Sp=TMV}u?%-eHACR`TIt73+%tsD@ zX*ET5@X&7^j=rd&r+pbX4pUYzafYwGM^kKvYM7{%C08)iFl6yVIS8wGPByBSsS(;| zdlZyR8;O{S(&kn43?EhMnm@3;C5Dd-pBO$<*e*@gRWoj?NdrA+hV$=N7;#-fdF#HC zD?5^*b~xlgt5myCky3f`n2nV-#Mb^rb8Y9=8fmz;lLGM!D>$yTiDPWadPBb%WL(gi rVG@7Un!E-x*c002Mp{o@L;IOAjSuR6c!W>s2xR>c{>7>qqS* diff --git a/imputegap/explainer/explainer.py b/imputegap/explainer/explainer.py index d28cf5d..0b3fcc3 100644 --- a/imputegap/explainer/explainer.py +++ b/imputegap/explainer/explainer.py @@ -114,8 +114,8 @@ def print(shap_values, shap_details=None): print(f"\tRMSE SERIES {i:<5} : {output:<15}") print("\n\nSHAP Results details : ") - for (x, algo, rate, description, feature, categorie, mean_features) in shap_values: - print(f"\tFeature : {x:<5} {algo:<10} with a score of {rate:<10} {categorie:<18} {description:<75} {feature}\n") + for (x, algo, rate, description, feature, category, mean_features) in shap_values: + print(f"\tFeature : {x:<5} {algo:<10} with a score of {rate:<10} {category:<18} {description:<75} {feature}\n") def convert_results(tmp, file, algo, descriptions, features, categories, mean_features, to_save): """ @@ -145,9 +145,9 @@ def convert_results(tmp, file, algo, descriptions, features, categories, mean_fe print(tup[2], end=",") with open(to_save + "_results.txt", 'w') as file_output: - for (x, algo, rate, description, feature, categorie, mean_features) in result_display: - file_output.write(f"Feature : {x:<5} {algo:<10} with a score of {rate:<10} {categorie:<18} {description:<65} {feature}\n") - result_shap.append([file, algo, rate, description, feature, categorie, mean_features]) + for (x, algo, rate, description, feature, category, mean_features) in result_display: + file_output.write(f"Feature : {x:<5} {algo:<10} with a score of {rate:<10} {category:<18} {description:<65} {feature}\n") + result_shap.append([file, algo, rate, description, feature, category, mean_features]) return result_shap @@ -426,7 +426,7 @@ def shap_explainer(ground_truth, algorithm="cdrec", params=None, contamination=" categories, features = Explainer.load_configuration() for current_series in range(0, limitation): - print("Generation ", current_series, "____________________________________________________________________") + print("Generation ", current_series, "___________________________________________________________________") print("\tContamination ", current_series, "...") if contamination == "mcar": diff --git a/tests/__pycache__/test_contamination_mcar.cpython-312.pyc b/tests/__pycache__/test_contamination_mcar.cpython-312.pyc index a94daa95b79002026574fa38302818e95fbcbbde..afa3b5a76f9066eb7efa435463e66668cda24ce2 100644 GIT binary patch delta 176 zcmeCS>#*ZJ&CAQh00h>8U(&8@yjd2U88x> z3&(eGJmBEz=kMg7!!)0LCi{H8nS2W>E=w3($c-P&CAQh00b;CZ_+Mo5ALnwyHTcn=| zFcp~qxkaWxqL>LtOx`Tx!`LuMHqeMesGq-+e+K^wu`RNji}--n003N5G3)>U diff --git a/tests/__pycache__/test_explainer.cpython-312.pyc b/tests/__pycache__/test_explainer.cpython-312.pyc index 7f9e2f5790b950d3789061f7b3aa2a5bd71502ea..7b32e2fda65a23fb74d306bcb6ac52e4cd151a58 100644 GIT binary patch delta 456 zcmcblvP6aVG%qg~0}$L*`;sQLk=K)jQEGEIi$ANT@LcxQAPFF7V)SB2;i(ar&5$B8 zml+|-!cfVfDZ2SQ=PpJrZgqv+(&7?@%)H49xs#Zg)D<=}@pv+FawI2~q^9Q=RZb4& z)fVPSttg0xiQEG7CU^6ev#S7YDZarpnUl|u`3axIWP83&4K{`##%RV1(P@_h(?2sv zaYiv-p z5OP64^O8XM7ONe(XOgc3N1TYcY#epTxa>-N;`R9Si}C3fGDM?icmkFY9@JVqoEwx+tJEfoDd}4PmJnIhTaBZb&OHsJSFEXI?Q1ux1e furj7Iei1NbjAESNaa~09qKN8O1`w}^52y|RM#7lw delta 502 zcmZ3Ya!G~vG%qg~0}u$SeM#fp$m_|%D7iVD#h+D^e=hrKkOUAkF?uni@YD#*W=Ij9 z%Zw0ZVW{NL6xn>9a~Gq8q`E?JMt*5dib7JVLS~*qa$-qpdVW#mbHG&G;+WjQTh6Www5XVwc`_THA?p`527bxOR(zcrTnxdCF^m^t(k};P zd}fg3jAFdV76u;{=l#A~UL2xNnF&z&rUh-`;x96(JX746pDQ-QW@E_wV$d z;k!ZPf{f`W1_oIdrae_Bbgx9kUWre+5Se<}J`E`9!060$B|GKm0VH9Y2$B#?WO^VFxqx$p`U=SlMiCd}BCiNUel(Vx zd{@AWPZ1Ps96+Ll;WLZrWK+S5a&oMU>5N|l%o(E?CwN>JQN1Xl`jr91E8+vH0{|f5 BntlKP diff --git a/tests/test_explainer.py b/tests/test_explainer.py index be7ebb9..f698d9d 100644 --- a/tests/test_explainer.py +++ b/tests/test_explainer.py @@ -22,6 +22,7 @@ def resolve_path(local_path, github_actions_path): else: raise FileNotFoundError("File not found in both: ", local_path, " and ", github_actions_path) + def get_file_path(set_name="test"): """ Find the accurate path for loading files of tests @@ -46,7 +47,6 @@ def test_explainer_shap(self): expected_categories, expected_features = Explainer.load_configuration() - gap = TimeSeries(data=get_file_path(filename)) shap_values, shap_details = Explainer.shap_explainer(ground_truth=gap.ts, file_name=filename, use_seed=True, @@ -58,9 +58,7 @@ def test_explainer_shap(self): for i, (_, output) in enumerate(shap_details): assert np.isclose(RMSE[i], output, atol=0.01) - - - for i, (x, algo, rate, description, feature, categorie, mean_features) in enumerate(shap_values): + for i, (x, algo, rate, description, feature, category, mean_features) in enumerate(shap_values): assert np.isclose(SHAP_VAL[i], rate, atol=0.01) self.assertTrue(x is not None and not (isinstance(x, (int, float)) and np.isnan(x))) @@ -68,14 +66,15 @@ def test_explainer_shap(self): self.assertTrue(rate is not None and not (isinstance(rate, (int, float)) and np.isnan(rate))) self.assertTrue(description is not None) self.assertTrue(feature is not None) - self.assertTrue(categorie is not None) - self.assertTrue(mean_features is not None and not (isinstance(mean_features, (int, float)) and np.isnan(mean_features))) + self.assertTrue(category is not None) + self.assertTrue( + mean_features is not None and not (isinstance(mean_features, (int, float)) and np.isnan(mean_features))) # Check relation feature/category feature_found_in_category = False - for category, features in expected_categories.items(): - if feature in features: - assert categorie == category, f"Feature '{feature}' should be in category '{category}', but is in '{categorie}'" + for exp_category, exp_features in expected_categories.items(): + if feature in exp_features: + assert category == exp_category, f"Feature '{feature}' must in '{exp_category}', but is in '{category}'" feature_found_in_category = True break assert feature_found_in_category, f"Feature '{feature}' not found in any category" @@ -83,8 +82,6 @@ def test_explainer_shap(self): # Check relation description/feature if feature in expected_features: expected_description = expected_features[feature] - assert description == expected_description, f"Feature '{feature}' has wrong description. Expected '{expected_description}', got '{description}'" + assert description == expected_description, f"Feature '{feature}' has wrong description. Expected '{expected_description}', got '{description}' " else: assert False, f"Feature '{feature}' not found in the FEATURES dictionary" - -