From 94cf7a92caa5f2bcbd42600a89792e069002c590 Mon Sep 17 00:00:00 2001 From: weijingchen Date: Thu, 21 Dec 2023 16:31:39 +0800 Subject: [PATCH] Signed-off-by: weijingchen Fix reg label check Signed-off-by: cwj --- .../ensemble/algo/secureboost/hetero/guest.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py index b9f22891c0..7b787abcc1 100644 --- a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py @@ -197,33 +197,33 @@ def _init_sample_scores(self, ctx: Context, label, train_data: DataFrame): def _check_label(self, label: DataFrame): - train_data_binarized_label = label.get_dummies() - labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns] - label_set = set(labels) - - if self.objective == MULTI_CE: - - if self.num_class is None or self.num_class <= 2: - raise ValueError( - f"num_class should be set and greater than 2 for multi:ce objective, but got {self.num_class}" - ) - - if len(label_set) > self.num_class: - raise ValueError( - f"num_class should be greater than or equal to the number of unique label in provided train data, but got {self.num_class} and {len(label_set)}" - ) - if max(label_set) - 1 > self.num_class: - raise ValueError( - f"the max label index in the provided train data should be less than or equal to num_class - 1, but got index {max(label_set)} which is > {self.num_class}" - ) - - elif self.objective == BINARY_BCE: - assert len(label_set) == 2, f"binary classification task should have 2 unique label, but got {label_set}" - assert ( - 0 in label_set and 1 in label_set - ), f"binary classification task should have label 0 and 1, but got {label_set}" - self.num_class = 2 + if self.objective != REGRESSION: + train_data_binarized_label = label.get_dummies() + labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns] + label_set = set(labels) + if self.objective == MULTI_CE: + + if self.num_class is None or self.num_class <= 2: + raise ValueError( + f"num_class should be set and greater than 2 for multi:ce objective, but got {self.num_class}" + ) + + if len(label_set) > self.num_class: + raise ValueError( + f"num_class should be greater than or equal to the number of unique label in provided train data, but got {self.num_class} and {len(label_set)}" + ) + if max(label_set) - 1 > self.num_class: + raise ValueError( + f"the max label index in the provided train data should be less than or equal to num_class - 1, but got index {max(label_set)} which is > {self.num_class}" + ) + + elif self.objective == BINARY_BCE: + assert len(label_set) == 2, f"binary classification task should have 2 unique label, but got {label_set}" + assert ( + 0 in label_set and 1 in label_set + ), f"binary classification task should have label 0 and 1, but got {label_set}" + self.num_class = 2 else: self.num_class = None