Skip to content

Commit

Permalink
Signed-off-by: weijingchen <talkingwallace@sohu.com>
Browse files Browse the repository at this point in the history
Fix examples & fix sbt label check

Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Dec 20, 2023
1 parent 0e7463c commit 8cecc32
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 36 deletions.
9 changes: 6 additions & 3 deletions python/fate/ml/ensemble/algo/secureboost/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,19 @@ def _init_sample_scores(self, ctx: Context, label, train_data: DataFrame):
self._accumulate_scores = self._loss_func.initialize(label)

def _check_label(self, label: DataFrame):
label_df = label.as_pd_df()[label.schema.label_name]

train_data_binarized_label = label.get_dummies()
labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns]
label_set = set(labels)


if self.objective == MULTI_CE:

if self.num_class is None or self.num_class <= 2:
raise ValueError(
f"num_class should be set and greater than 2 for multi:ce objective, but got {self.num_class}"
)

label_set = set(np.unique(label_df))
if len(label_set) > self.num_class:
raise ValueError(
f"num_class should be greater than or equal to the number of unique label in provided train data, but got {self.num_class} and {len(label_set)}"
Expand All @@ -215,7 +219,6 @@ def _check_label(self, label: DataFrame):
)

elif self.objective == BINARY_BCE:
label_set = set(np.unique(label_df))
assert len(label_set) == 2, f"binary classification task should have 2 unique label, but got {label_set}"
assert (
0 in label_set and 1 in label_set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_ctx(local):
console_handler.setFormatter(formatter)

logger.addHandler(console_handler)
computing = CSession()
computing = CSession(data_dir='./')
return Context(
computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])
)
Expand Down
60 changes: 28 additions & 32 deletions python/fate/ml/nn/test/test_hetero_nn_sshe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def get_current_datetime_str():

def create_ctx(local, context_name):
from fate.arch import Context
from fate.arch.computing.standalone import CSession
from fate.arch.federation.standalone import StandaloneFederation
from fate.arch.computing.backends.standalone import CSession
from fate.arch.federation.backends.standalone import StandaloneFederation
import logging

# prepare log
Expand Down Expand Up @@ -49,20 +49,27 @@ def set_seed(seed):
t.backends.cudnn.deterministic = True
t.backends.cudnn.benchmark = False

set_seed(42)

batch_size = 64
epoch = 10
guest_bottom = t.nn.Linear(10, 4)
guest_bottom = t.nn.Linear(10, 10)
guest_top = t.nn.Sequential(
t.nn.Linear(4, 1),
t.nn.Linear(10, 1),
t.nn.Sigmoid()
)
host_bottom = t.nn.Linear(20, 4)
host_bottom = t.nn.Linear(20, 10)

# # make random fake data
sample_num = 569

args = TrainingArguments(
num_train_epochs=1,
per_device_train_batch_size=256,
logging_strategy='epoch',
no_cuda=True,
log_level='debug',
disable_tqdm=False
)

if party == "guest":

from fate.ml.evaluation.metric_base import MetricEnsemble
Expand All @@ -80,21 +87,13 @@ def set_seed(seed):
top_model=guest_top,
bottom_model=guest_bottom,
agglayer_arg=SSHEArgument(
guest_in_features=4,
host_in_features=4,
out_features=4,
layer_lr=0.01
guest_in_features=10,
host_in_features=10,
out_features=10
)
)
model
optimizer = t.optim.Adam(model.parameters(), lr=0.01)

args = TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=16,
no_cuda=True,
disable_tqdm=False
)
trainer = HeteroNNTrainerGuest(
ctx=ctx,
model=model,
Expand All @@ -106,10 +105,15 @@ def set_seed(seed):
compute_metrics=MetricEnsemble().add_metric(MultiAccuracy())
)
trainer.train()
pred = trainer.predict(dataset)
pred_0 = trainer.predict(dataset)
# # compute auc
from sklearn.metrics import roc_auc_score
print(roc_auc_score(pred.label_ids, pred.predictions))
print(roc_auc_score(pred_0.label_ids, pred_0.predictions))

pred_1 = trainer.predict(dataset)
# # compute auc
from sklearn.metrics import roc_auc_score
print(roc_auc_score(pred_1.label_ids, pred_1.predictions))

elif party == "host":

Expand All @@ -123,21 +127,13 @@ def set_seed(seed):
model = HeteroNNModelHost(
bottom_model=host_bottom,
agglayer_arg=SSHEArgument(
guest_in_features=4,
host_in_features=4,
out_features=4,
layer_lr=0.01
guest_in_features=10,
host_in_features=10,
out_features=10
)
)
optimizer = t.optim.Adam(model.parameters(), lr=0.01)

args = TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=16,
no_cuda=True,
disable_tqdm=False
)

trainer = HeteroNNTrainerHost(
ctx=ctx,
model=model,
Expand All @@ -148,4 +144,4 @@ def set_seed(seed):
)
trainer.train()
trainer.predict(dataset)

trainer.predict(dataset)

0 comments on commit 8cecc32

Please sign in to comment.