Skip to content

Commit

Permalink
Merge pull request #5363 from FederatedAI/dev-2.0.0-rc-final-round-fix
Browse files Browse the repository at this point in the history
Dev 2.0.0 rc final round fix
  • Loading branch information
mgqa34 authored Dec 21, 2023
2 parents 108093e + 94cf7a9 commit 19ee92c
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions python/fate/ml/ensemble/algo/secureboost/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,33 +197,33 @@ def _init_sample_scores(self, ctx: Context, label, train_data: DataFrame):

def _check_label(self, label: DataFrame):

train_data_binarized_label = label.get_dummies()
labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns]
label_set = set(labels)


if self.objective == MULTI_CE:

if self.num_class is None or self.num_class <= 2:
raise ValueError(
f"num_class should be set and greater than 2 for multi:ce objective, but got {self.num_class}"
)

if len(label_set) > self.num_class:
raise ValueError(
f"num_class should be greater than or equal to the number of unique label in provided train data, but got {self.num_class} and {len(label_set)}"
)
if max(label_set) - 1 > self.num_class:
raise ValueError(
f"the max label index in the provided train data should be less than or equal to num_class - 1, but got index {max(label_set)} which is > {self.num_class}"
)

elif self.objective == BINARY_BCE:
assert len(label_set) == 2, f"binary classification task should have 2 unique label, but got {label_set}"
assert (
0 in label_set and 1 in label_set
), f"binary classification task should have label 0 and 1, but got {label_set}"
self.num_class = 2
if self.objective != REGRESSION:
train_data_binarized_label = label.get_dummies()
labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns]
label_set = set(labels)
if self.objective == MULTI_CE:

if self.num_class is None or self.num_class <= 2:
raise ValueError(
f"num_class should be set and greater than 2 for multi:ce objective, but got {self.num_class}"
)

if len(label_set) > self.num_class:
raise ValueError(
f"num_class should be greater than or equal to the number of unique label in provided train data, but got {self.num_class} and {len(label_set)}"
)
if max(label_set) - 1 > self.num_class:
raise ValueError(
f"the max label index in the provided train data should be less than or equal to num_class - 1, but got index {max(label_set)} which is > {self.num_class}"
)

elif self.objective == BINARY_BCE:
assert len(label_set) == 2, f"binary classification task should have 2 unique label, but got {label_set}"
assert (
0 in label_set and 1 in label_set
), f"binary classification task should have label 0 and 1, but got {label_set}"
self.num_class = 2
else:
self.num_class = None

Expand Down

0 comments on commit 19ee92c

Please sign in to comment.