Skip to content

Commit

Permalink
fix : __compare_weight_handler__ function added #347
Browse files Browse the repository at this point in the history
  • Loading branch information
sepandhaghighi committed Oct 10, 2021
1 parent 95d6939 commit 80457f1
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions pycm/pycm_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,35 @@ def __compare_sort_handler__(compare):
return (max_overall_name, max_class_name)


def __compare_weight_handler__(compare, weight, weight_type):
"""
Handle different weights validation.
:param compare: Compare
:type compare : pycm.Compare object
:param weight: input weight
:type weight: dict
:param weight_type: input weight type
:type weight_type: str
:return: None
"""
valid_dict = {"class_weight":compare.classes,"class_benchmark_weight":CLASS_BENCHMARK_SCORE_DICT.keys(),"overall_benchmark_weight":OVERALL_BENCHMARK_SCORE_DICT.keys()}
error_dict = {"class_weight":COMPARE_CLASS_WEIGHT_ERROR, "class_benchmark_weight":COMPARE_CLASS_BENCHMARK_WEIGHT_ERROR, "overall_benchmark_weight":COMPARE_OVERALL_BENCHMARK_WEIGHT_ERROR}
warning_dict = {"class_weight":COMPARE_CLASS_WEIGHT_WARNING, "class_benchmark_weight":COMPARE_CLASS_BENCHMARK_WEIGHT_WARNING,"overall_benchmark_weight":COMPARE_OVERALL_BENCHMARK_WEIGHT_WARNING}
if weight is not None:
if not isinstance(weight, dict):
raise pycmCompareError(error_dict[weight_type])
if set(weight.keys()) == set(valid_dict[weight_type]):
if all([isfloat(x) for x in weight.values()]
) and sum(weight.values()) != 0:
setattr(compare,weight_type,weight)
else:
warn(warning_dict[weight_type], RuntimeWarning)
else:
raise pycmCompareError(error_dict[weight_type])



def __compare_assign_handler__(compare, cm_dict, class_weight, class_benchmark_weight, overall_benchmark_weight, digit):
"""
Assign basic parameters to Compare.
Expand Down Expand Up @@ -260,14 +289,6 @@ def __compare_assign_handler__(compare, cm_dict, class_weight, class_benchmark_w
compare.sorted = None
compare.scores = {k: {"overall": 0, "class": 0}.copy()
for k in cm_dict.keys()}
if class_weight is not None:
if not isinstance(class_weight, dict):
raise pycmCompareError(COMPARE_CLASS_WEIGHT_ERROR)
if set(class_weight.keys()) == set(compare.classes):
if all([isfloat(x) for x in class_weight.values()]
) and sum(class_weight.values()) != 0:
compare.class_weight = class_weight
else:
warn(COMPARE_CLASS_WEIGHT_WARNING, RuntimeWarning)
else:
raise pycmCompareError(COMPARE_CLASS_WEIGHT_ERROR)
__compare_weight_handler__(compare,class_weight,"class_weight")
__compare_weight_handler__(compare,class_benchmark_weight,"class_benchmark_weight")
__compare_weight_handler__(compare,overall_benchmark_weight,"overall_benchmark_weight")

0 comments on commit 80457f1

Please sign in to comment.