Skip to content

Commit

Permalink
Correction of the SHAP explainer and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
qnater committed Sep 25, 2024
1 parent ee6b97b commit 0dac9fe
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 71 deletions.
56 changes: 3 additions & 53 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified imputegap/explainer/__pycache__/explainer.cpython-312.pyc
Binary file not shown.
12 changes: 6 additions & 6 deletions imputegap/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def print(shap_values, shap_details=None):
print(f"\tRMSE SERIES {i:<5} : {output:<15}")

print("\n\nSHAP Results details : ")
for (x, algo, rate, description, feature, categorie, mean_features) in shap_values:
print(f"\tFeature : {x:<5} {algo:<10} with a score of {rate:<10} {categorie:<18} {description:<75} {feature}\n")
for (x, algo, rate, description, feature, category, mean_features) in shap_values:
print(f"\tFeature : {x:<5} {algo:<10} with a score of {rate:<10} {category:<18} {description:<75} {feature}\n")

def convert_results(tmp, file, algo, descriptions, features, categories, mean_features, to_save):
"""
Expand Down Expand Up @@ -145,9 +145,9 @@ def convert_results(tmp, file, algo, descriptions, features, categories, mean_fe
print(tup[2], end=",")

with open(to_save + "_results.txt", 'w') as file_output:
for (x, algo, rate, description, feature, categorie, mean_features) in result_display:
file_output.write(f"Feature : {x:<5} {algo:<10} with a score of {rate:<10} {categorie:<18} {description:<65} {feature}\n")
result_shap.append([file, algo, rate, description, feature, categorie, mean_features])
for (x, algo, rate, description, feature, category, mean_features) in result_display:
file_output.write(f"Feature : {x:<5} {algo:<10} with a score of {rate:<10} {category:<18} {description:<65} {feature}\n")
result_shap.append([file, algo, rate, description, feature, category, mean_features])

return result_shap

Expand Down Expand Up @@ -426,7 +426,7 @@ def shap_explainer(ground_truth, algorithm="cdrec", params=None, contamination="
categories, features = Explainer.load_configuration()

for current_series in range(0, limitation):
print("Generation ", current_series, "____________________________________________________________________")
print("Generation ", current_series, "___________________________________________________________________")
print("\tContamination ", current_series, "...")

if contamination == "mcar":
Expand Down
Binary file modified tests/__pycache__/test_contamination_mcar.cpython-312.pyc
Binary file not shown.
Binary file modified tests/__pycache__/test_explainer.cpython-312.pyc
Binary file not shown.
21 changes: 9 additions & 12 deletions tests/test_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def resolve_path(local_path, github_actions_path):
else:
raise FileNotFoundError("File not found in both: ", local_path, " and ", github_actions_path)


def get_file_path(set_name="test"):
"""
Find the accurate path for loading files of tests
Expand All @@ -46,7 +47,6 @@ def test_explainer_shap(self):

expected_categories, expected_features = Explainer.load_configuration()


gap = TimeSeries(data=get_file_path(filename))

shap_values, shap_details = Explainer.shap_explainer(ground_truth=gap.ts, file_name=filename, use_seed=True,
Expand All @@ -58,33 +58,30 @@ def test_explainer_shap(self):
for i, (_, output) in enumerate(shap_details):
assert np.isclose(RMSE[i], output, atol=0.01)



for i, (x, algo, rate, description, feature, categorie, mean_features) in enumerate(shap_values):
for i, (x, algo, rate, description, feature, category, mean_features) in enumerate(shap_values):
assert np.isclose(SHAP_VAL[i], rate, atol=0.01)

self.assertTrue(x is not None and not (isinstance(x, (int, float)) and np.isnan(x)))
self.assertTrue(algo is not None)
self.assertTrue(rate is not None and not (isinstance(rate, (int, float)) and np.isnan(rate)))
self.assertTrue(description is not None)
self.assertTrue(feature is not None)
self.assertTrue(categorie is not None)
self.assertTrue(mean_features is not None and not (isinstance(mean_features, (int, float)) and np.isnan(mean_features)))
self.assertTrue(category is not None)
self.assertTrue(
mean_features is not None and not (isinstance(mean_features, (int, float)) and np.isnan(mean_features)))

# Check relation feature/category
feature_found_in_category = False
for category, features in expected_categories.items():
if feature in features:
assert categorie == category, f"Feature '{feature}' should be in category '{category}', but is in '{categorie}'"
for exp_category, exp_features in expected_categories.items():
if feature in exp_features:
assert category == exp_category, f"Feature '{feature}' must in '{exp_category}', but is in '{category}'"
feature_found_in_category = True
break
assert feature_found_in_category, f"Feature '{feature}' not found in any category"

# Check relation description/feature
if feature in expected_features:
expected_description = expected_features[feature]
assert description == expected_description, f"Feature '{feature}' has wrong description. Expected '{expected_description}', got '{description}'"
assert description == expected_description, f"Feature '{feature}' has wrong description. Expected '{expected_description}', got '{description}' "
else:
assert False, f"Feature '{feature}' not found in the FEATURES dictionary"


0 comments on commit 0dac9fe

Please sign in to comment.