Skip to content

Commit

Permalink
fix : compare weight renamed to class_weight #347
Browse files Browse the repository at this point in the history
  • Loading branch information
sepandhaghighi authored and alirezazolanvari committed Nov 28, 2021
1 parent cd8ce57 commit 665c85c
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions pycm/pycm_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ class Compare():
'cm1'
"""

def __init__(self, cm_dict, by_class=False, weight=None, digit=5):
def __init__(self, cm_dict, by_class=False, class_weight=None, digit=5):
"""
Init method.
:param cm_dict: cm's dictionary
:type cm_dict : dict
:param by_class: compare by class flag
:type by_class: bool
:param weight: class weights
:type weight: dict
:param class_weight: class weights
:type class_weight: dict
:param digit: precision digit (default value : 5)
:type digit : int
"""
self.scores = None
self.sorted = None
self.classes = None
__compare_assign_handler__(self, cm_dict, weight, digit)
__compare_assign_handler__(self, cm_dict, class_weight, digit)
__compare_class_handler__(self, cm_dict)
__compare_overall_handler__(self, cm_dict)
__compare_rounder__(self, cm_dict)
Expand Down Expand Up @@ -143,7 +143,7 @@ def __compare_class_handler__(compare, cm_dict):
cm.class_stat[item][c]] for cm in cm_dict.values()]
if all([isinstance(x, int) for x in all_class_score]):
for cm_name in cm_dict.keys():
compare.scores[cm_name]["class"] += compare.weight[c] * (
compare.scores[cm_name]["class"] += compare.class_weight[c] * (
CLASS_BENCHMARK_SCORE_DICT[item][cm_dict[cm_name].class_stat[item][c]] / max_item_score)


Expand Down Expand Up @@ -211,7 +211,7 @@ def __compare_sort_handler__(compare):
return (max_overall_name, max_class_name)


def __compare_assign_handler__(compare, cm_dict, weight, digit):
def __compare_assign_handler__(compare, cm_dict, class_weight, digit):
"""
Assign basic parameters to Comapre.
Expand All @@ -221,8 +221,8 @@ def __compare_assign_handler__(compare, cm_dict, weight, digit):
:type cm_dict : dict
:param digit: precision digit (default value : 5)
:type digit : int
:param weight: class weights
:type weight: dict
:param class_weight: class weights
:type class_weight: dict
:return: None
"""
if not isinstance(cm_dict, dict):
Expand All @@ -236,18 +236,18 @@ def __compare_assign_handler__(compare, cm_dict, weight, digit):
if len(cm_dict) < 2:
raise pycmCompareError(COMPARE_NUMBER_ERROR)
compare.classes = list(cm_dict.values())[0].classes
compare.weight = {k: 1 for k in compare.classes}
compare.class_weight = {k: 1 for k in compare.classes}
compare.digit = digit
compare.best = None
compare.best_name = None
compare.sorted = None
compare.scores = {k: {"overall": 0, "class": 0}.copy()
for k in cm_dict.keys()}
if weight is not None:
if not isinstance(weight, dict):
if class_weight is not None:
if not isinstance(class_weight, dict):
raise pycmCompareError(COMPARE_WEIGHT_ERROR)
if set(weight.keys()) == set(compare.classes) and all(
[isfloat(x) for x in weight.values()]):
compare.weight = weight
if set(class_weight.keys()) == set(compare.classes) and all(
[isfloat(x) for x in class_weight.values()]):
compare.class_weight = class_weight
else:
raise pycmCompareError(COMPARE_WEIGHT_ERROR)

0 comments on commit 665c85c

Please sign in to comment.