Skip to content

Commit

Permalink
Fix bugs of hetero-nn
Browse files: browse the repository at this point in the history
Signed-off-by: weijingchen <talkingwallace@sohu.com>

Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Dec 4, 2023
1 parent 337a8a1 commit 5a64ef6
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 24 deletions.
6 changes: 3 additions & 3 deletions examples/pipeline/hetero_nn/test_nn_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main(config="../../config.yaml", namespace=""):
num_train_epochs=5,
per_device_train_batch_size=16,
logging_strategy='epoch',
no_cuda=True,
no_cuda=True
)

guest_conf = get_config_of_default_runner(
Expand All @@ -70,7 +70,7 @@ def main(config="../../config.yaml", namespace=""):

hetero_nn_0 = HeteroNN(
'hetero_nn_0',
train_data=psi_0.outputs['output_data']
train_data=psi_0.outputs['output_data'], validate_data=psi_0.outputs['output_data']
)

hetero_nn_0.guest.task_setting(runner_conf=guest_conf)
Expand All @@ -86,7 +86,7 @@ def main(config="../../config.yaml", namespace=""):
'eval_0',
runtime_roles=['guest'],
metrics=['auc'],
input_data=[hetero_nn_1.outputs['predict_data_output'], hetero_nn_0.outputs['train_data_output']]
input_data=[hetero_nn_0.outputs['train_data_output']]
)

pipeline.add_task(psi_0)
Expand Down
5 changes: 2 additions & 3 deletions examples/pipeline/hetero_nn/test_nn_binary_with_fedpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ def main(config="../../config.yaml", namespace=""):
namespace="experiment"))

training_args = TrainingArguments(
num_train_epochs=5,
num_train_epochs=1,
per_device_train_batch_size=16,
logging_strategy='epoch',
no_cuda=True,
disable_tqdm=False
no_cuda=True
)

guest_conf = get_config_of_default_runner(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def main(config="../config.yaml", namespace=""):
pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)

psi_0 = PSI("psi_0")
psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="egc_hetero_guest",
namespace="experiment"))
psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="egc_hetero_host",
namespace="experiment"))

hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=3, max_bin=32, max_depth=3, goss=True,
hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=2, max_bin=32, max_depth=3, goss=True, top_rate=0.2,
he_param={'kind': 'paillier', 'key_length': 1024}, train_data=psi_0.outputs['output_data'],)
evaluation_0 = Evaluation(
'eval_0',
Expand Down
2 changes: 1 addition & 1 deletion python/fate/components/components/hetero_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def train(
train_model_output,
train_model_input
)

logger.info('cwj done')

@hetero_nn.predict()
def predict(
Expand Down
8 changes: 6 additions & 2 deletions python/fate/components/components/nn/component_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ def train_procedure(

logger.info('Predicting Train & Validate Data')
train_pred = runner.predict(train_data_, saved_model_path)
validate_pred = None
if validate_data_ is not None:
validate_pred = runner.predict(validate_data_)

logger.info('predicting done')
if train_pred is not None:
assert isinstance(
train_pred, DataFrame), "train predict result should be a DataFrame"
add_dataset_type(train_pred, consts.TRAIN_SET)

if validate_data_ is not None:
validate_pred = runner.predict(validate_data_)
if validate_pred is not None:
assert isinstance(
validate_pred, DataFrame), "validate predict result should be a DataFrame"
add_dataset_type(validate_pred, consts.VALIDATE_SET)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def predict(self,
output_dir: str = None,
saved_model_path: str = None) -> DataFrame:

logger.info('cwj pred called')
test_set = self._prepare_data(test_data, 'test_data')
if self.trainer is not None:
trainer = self.trainer
Expand Down Expand Up @@ -271,8 +272,9 @@ def predict(self,
dataframe_format='fate_std',
task_type=self.task_type,
classes=classes)

logger.info('cwj pred end guest')
return rs_df

elif self.is_host():
trainer.predict(test_set)
logger.info('cwj pred end host')
3 changes: 2 additions & 1 deletion python/fate/ml/ensemble/algo/secureboost/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _extend_score(s: pd.Series, class_num, dim):
acc_scores = acc_scores + extend_scores * learning_rate
return acc_scores


class HeteroSecureBoostGuest(HeteroBoostingTree):
def __init__(
self,
Expand Down Expand Up @@ -404,7 +405,7 @@ def from_model(self, model: dict):
self._init_score = float(model["init_score"]) if model["init_score"] is not None else None
# initialize
self._tree_dim = self.num_class if self.objective == MULTI_CE else 1
self._loss_func = _get_loss_func(self.objective)
self._loss_func = _get_loss_func(self.objective, class_num=self.num_class)
# for warmstart
self._model_loaded = True
# load loss
Expand Down
6 changes: 3 additions & 3 deletions python/fate/ml/glm/homo/lr/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fate.arch import Context
import logging
import torch as t
from fate.ml.nn.homo import FedAVGCLient, TrainingArguments, FedAVGArguments
from fate.ml.nn.homo.fedavg import FedAVGClient, TrainingArguments, FedAVGArguments
from transformers import default_data_collator
import functools
import tempfile
Expand Down Expand Up @@ -408,7 +408,7 @@ def fit(self, ctx: Context, train_data: DataFrame,
num_train_epochs=self.max_iter,
per_device_train_batch_size=self.batch_size,
per_device_eval_batch_size=self.batch_size)
self.trainer = FedAVGCLient(
self.trainer = FedAVGClient(
ctx,
model=self.model,
loss_fn=loss_fn,
Expand Down Expand Up @@ -436,7 +436,7 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:
train_arg = TrainingArguments(
num_train_epochs=self.max_iter,
per_device_eval_batch_size=batch_size)
trainer = FedAVGCLient(
trainer = FedAVGClient(
ctx,
train_set=self.predict_set,
model=self.model,
Expand Down
4 changes: 2 additions & 2 deletions python/fate/ml/glm/homo/lr/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fate.arch.dataframe import DataFrame
from fate.arch import Context
import logging
from fate.ml.nn.homo import FedAVGServer
from fate.ml.nn.homo.fedavg import FedAVGServer


logger = logging.getLogger(__name__)
Expand All @@ -25,4 +25,4 @@ def predict(
ctx: Context,
predict_data: DataFrame = None) -> DataFrame:

logger.info('kkip prediction stage')
logger.info('skip prediction stage')
11 changes: 7 additions & 4 deletions python/fate/ml/nn/model_zoo/agg_layer/fedpass/_passport_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def set_key(self, skey, bkey):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):

if 'skey' in state_dict:
self.register_buffer('skey', t.randn(*state_dict['skey'].size()))
if 'bkey' in state_dict:
self.register_buffer('bkey', t.randn(*state_dict['bkey'].size()))
skey = '_agg_layer._model.skey'
bkey = '_agg_layer._model.bkey'

if skey in state_dict:
self.register_buffer('skey', t.randn(*state_dict[skey].size()))
if bkey in state_dict:
self.register_buffer('bkey', t.randn(*state_dict[bkey].size()))

if '_out_scale' in state_dict:
self.scale = nn.Parameter(t.randn(*state_dict['_out_scale'].size()))
Expand Down

0 comments on commit 5a64ef6

Please sign in to comment.