Skip to content

Commit

Permalink
Update fedpass
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Nov 15, 2023
1 parent 5e74cff commit fa2704a
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 68 deletions.
8 changes: 2 additions & 6 deletions python/fate/ml/nn/hetero/hetero_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def __init__(
training_args: TrainingArguments,
train_set: Dataset,
val_set: Dataset = None,
agg_layer_arguments: Union[StdAggLayerArgument, FedPassArgument, HESSArgument] = None,
top_model_arguments: TopModelArguments = None,
loss_fn: nn.Module = None,
optimizer = None,
data_collator: Callable = None,
Expand All @@ -38,6 +36,7 @@ def __init__(

assert isinstance(model, HeteroNNModelGuest), ('Model should be a HeteroNNModelGuest instance, '
'but got {}.').format(type(model))
model.setup(ctx=ctx)

super().__init__(
ctx=ctx,
Expand All @@ -54,7 +53,6 @@ def __init__(
compute_metrics=compute_metrics
)

model.setup(ctx, agglayer_arg=agg_layer_arguments, top_arg=top_model_arguments)

def compute_loss(self, model, inputs, **kwargs):
# (features, labels), this format is used in FATE-1.x
Expand Down Expand Up @@ -120,7 +118,6 @@ def __init__(
training_args: TrainingArguments,
train_set: Dataset,
val_set: Dataset = None,
agg_layer_arguments: Union[StdAggLayerArgument, FedPassArgument, HESSArgument] = None,
optimizer=None,
data_collator: Callable = None,
scheduler=None,
Expand All @@ -130,7 +127,7 @@ def __init__(
):
assert isinstance(model, HeteroNNModelHost), ('Model should be a HeteroNNModelHost instance, '
'but got {}.').format(type(model))

model.setup(ctx=ctx)
super().__init__(
ctx=ctx,
model=model,
Expand All @@ -145,7 +142,6 @@ def __init__(
callbacks=callbacks,
compute_metrics=compute_metrics
)
model.setup(ctx, agglayer_arg=agg_layer_arguments)

def compute_loss(self, model, inputs, **kwargs):
# host side not computing loss
Expand Down
70 changes: 44 additions & 26 deletions python/fate/ml/nn/model_zoo/hetero_nn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,13 @@ class TopModelArguments(Args):

protect_strategy: Literal['fedpass'] = None
fed_pass_arg: FedPassArgument = None
add_output_layer: Literal[None, 'sigmoid', 'softmax'] = None

def __post_init__(self):
if self.protect_strategy == 'fedpass' and not isinstance(self.fed_pass_arg, FedPassArgument):
raise TypeError("fed_pass_arg must be an instance of FedPassArgument for protect_strategy 'fedpass'")
assert self.add_output_layer in [None, 'sigmoid', 'softmax'], \
"add_output_layer must be None, 'sigmoid' or 'softmax'"


def backward_loss(z, backward_error):
Expand All @@ -100,7 +103,10 @@ class HeteroNNModelGuest(HeteroNNModelBase):

def __init__(self,
top_model: t.nn.Module,
bottom_model: t.nn.Module = None
bottom_model: t.nn.Module = None,
agglayer_arg: Union[StdAggLayerArgument, FedPassArgument, HESSArgument] = None,
top_arg: TopModelArguments = None,
ctx: Context = None
):

super(HeteroNNModelGuest, self).__init__()
Expand All @@ -124,6 +130,7 @@ def __init__(self,
self._top_strategy = None
# top additional model
self._top_add_model = None
self.setup(ctx=ctx, agglayer_arg=agglayer_arg, top_arg=top_arg, bottom_arg=None)

def __repr__(self):
return (f"HeteroNNGuest(top_model={self._top_model}\n"
Expand All @@ -141,23 +148,30 @@ def setup(self, ctx:Context = None, agglayer_arg: Union[StdAggLayerArgument, Fed
top_arg: TopModelArguments = None, bottom_arg=None):

self._ctx = ctx
if agglayer_arg is None:
self._agg_layer = AggLayerGuest()
elif isinstance(agglayer_arg, StdAggLayerArgument):
self._agg_layer = AggLayerGuest(**agglayer_arg.to_dict())
elif isinstance(agglayer_arg, FedPassArgument):
self._agg_layer = FedPassAggLayerGuest(**agglayer_arg.to_dict())

if top_arg:
logger.info('detect top model strategy')
if top_arg.protect_strategy == 'fedpass':
fedpass_arg = top_arg.fed_pass_arg
top_fedpass_model = get_model(**fedpass_arg.to_dict())
self._top_add_model = top_fedpass_model
self._top_model = t.nn.Sequential(
self._top_model,
top_fedpass_model
)

if self._agg_layer is None:
if agglayer_arg is None:
self._agg_layer = AggLayerGuest()
elif isinstance(agglayer_arg, StdAggLayerArgument):
self._agg_layer = AggLayerGuest(**agglayer_arg.to_dict())
elif isinstance(agglayer_arg, FedPassArgument):
self._agg_layer = FedPassAggLayerGuest(**agglayer_arg.to_dict())

if self._top_add_model is None:
if top_arg:
logger.info('detect top model strategy')
if top_arg.protect_strategy == 'fedpass':
fedpass_arg = top_arg.fed_pass_arg
top_fedpass_model = get_model(**fedpass_arg.to_dict())
self._top_add_model = top_fedpass_model
self._top_model = t.nn.Sequential(
self._top_model,
top_fedpass_model
)
if top_arg.add_output_layer == 'sigmoid':
self._top_model.add_module('sigmoid', t.nn.Sigmoid())
elif top_arg.add_output_layer == 'softmax':
self._top_model.add_module('softmax', t.nn.Softmax(dim=1))

self._agg_layer.set_context(ctx)

Expand Down Expand Up @@ -218,7 +232,9 @@ def predict(self, x = None):
class HeteroNNModelHost(HeteroNNModelBase):

def __init__(self,
bottom_model: t.nn.Module
bottom_model: t.nn.Module,
agglayer_arg: Union[StdAggLayerArgument, FedPassArgument, HESSArgument] = None,
ctx: Context = None
):

super().__init__()
Expand All @@ -229,8 +245,8 @@ def __init__(self,
self._bottom_fw = None # for backward usage
# ctx
self._ctx = None

self._agg_layer = None
self.setup(ctx=ctx, agglayer_arg=agglayer_arg)

def __repr__(self):
return f"HeteroNNHost(bottom_model={self._bottom_model}, agg_layer={self._agg_layer})"
Expand All @@ -245,12 +261,14 @@ def setup(self, ctx:Context = None, agglayer_arg: Union[StdAggLayerArgument, Fed
bottom_arg=None):

self._ctx = ctx
if agglayer_arg is None:
self._agg_layer = AggLayerHost()
elif type(agglayer_arg) == StdAggLayerArgument:
self._agg_layer = AggLayerHost() # no parameters are needed
elif type(agglayer_arg) == FedPassArgument:
self._agg_layer = FedPassAggLayerHost(**agglayer_arg.to_dict())

if self._agg_layer is None:
if agglayer_arg is None:
self._agg_layer = AggLayerHost()
elif type(agglayer_arg) == StdAggLayerArgument:
self._agg_layer = AggLayerHost() # no parameters are needed
elif type(agglayer_arg) == FedPassArgument:
self._agg_layer = FedPassAggLayerHost(**agglayer_arg.to_dict())

self._agg_layer.set_context(ctx)

Expand Down
76 changes: 55 additions & 21 deletions python/fate/ml/nn/test/test_fedpass_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,14 @@ def forward(self, x):

class LeNet_Top(nn.Module):

def __init__(self):
def __init__(self, out_feat=10):
super(LeNet_Top, self).__init__()
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc1act = nn.ReLU(inplace=True)
self.fc2 = nn.Linear(120, 84)
self.fc2act = nn.ReLU(inplace=True)
self.fc3 = nn.Linear(84, 10)
self.fc3 = nn.Linear(84, out_feat)

def forward(self, x_a):
x = x_a
Expand Down Expand Up @@ -277,7 +277,7 @@ def __getitem__(self, item):
# optimizer.step()
# print(loss_sum / len(train_loader))

arg = TrainingArguments(num_train_epochs=1, per_device_train_batch_size=16, disable_tqdm=False,
arg = TrainingArguments(num_train_epochs=20, per_device_train_batch_size=16, disable_tqdm=False,
eval_steps=1,
evaluation_strategy='epoch'
)
Expand All @@ -286,11 +286,24 @@ def __getitem__(self, item):

from fate.ml.evaluation.metric_base import MetricEnsemble
from fate.ml.evaluation.classification import MultiAccuracy
ctx = create_ctx(guest, get_current_datetime_str())
from fate.ml.nn.model_zoo.hetero_nn_model import TopModelArguments, FedPassArgument

top_model = LeNet_Top()
ctx = create_ctx(guest, get_current_datetime_str())
top_model = LeNet_Top(out_feat=10)
model = HeteroNNModelGuest(
top_model=top_model
top_model=top_model,
# top_arg=TopModelArguments(
# protect_strategy='fedpass',
# fed_pass_arg=FedPassArgument(
# layer_type='linear',
# num_passport=64,
# in_channels_or_features=84,
# hidden_features=64,
# out_channels_or_features=10,
# passport_mode='single'
# )
# ),
ctx=ctx
)
loss = nn.CrossEntropyLoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)
Expand All @@ -299,34 +312,55 @@ def __getitem__(self, item):
train_set=NoFeatureDataset(subset_train_data),
val_set=NoFeatureDataset(subset_val_data),
loss_fn=loss, optimizer=optimizer,
compute_metrics=MetricEnsemble().add_metric(MultiAccuracy())
compute_metrics=MetricEnsemble().add_metric(MultiAccuracy()),
)
trainer.train()



if party == 'host':

ctx = create_ctx(host, get_current_datetime_str())

bottom_model = LeNetBottom()
model = HeteroNNModelHost(
bottom_model=bottom_model
bottom_model=bottom_model,
agglayer_arg=FedPassArgument(
layer_type='conv',
in_channels_or_features=8,
out_channels_or_features=16,
kernel_size=(5, 5),
stride=(1, 1),
passport_mode='multi',
activation='relu',
num_passport=64
)
)
optimizer = t.optim.Adam(model.parameters(), lr=0.01)

trainer = HeteroNNTrainerHost(ctx, model, training_args=arg,
train_set=subset_train_data,
val_set=subset_val_data,
agg_layer_arguments=FedPassArgument(
layer_type='conv',
in_channels_or_features=8,
out_channels_or_features=16,
kernel_size=(5, 5),
stride=(1, 1),
passport_mode='multi',
activation='relu',
num_passport=64
),
optimizer=optimizer)
trainer.train()
trainer.train()

elif party == 'test':
from fate.ml.nn.model_zoo.hetero_nn_model import HeteroNNModelGuest, TopModelArguments, FedPassArgument

top_model = LeNet_Top(out_feat=84)
model = HeteroNNModelGuest(
top_model=top_model
)
model.setup(
top_arg=TopModelArguments(
protect_strategy='fedpass',
fed_pass_arg=FedPassArgument(
num_passport=64,
in_channels_or_features=84,
hidden_features=64,
out_channels_or_features=10,
passport_mode='single'
)
)
)



13 changes: 7 additions & 6 deletions python/fate/ml/nn/test/test_hetero_nn_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def set_seed(seed):
loss_fn=loss_fn,
training_args=args,
)
trainer.train()
pred = trainer.predict(dataset)
# compute auc
from sklearn.metrics import roc_auc_score
print(roc_auc_score(pred.label_ids, pred.predictions))
# trainer.train()
# pred = trainer.predict(dataset)
# # compute auc
# from sklearn.metrics import roc_auc_score
# print(roc_auc_score(pred.label_ids, pred.predictions))

elif party == "host":

Expand Down Expand Up @@ -127,4 +127,5 @@ def set_seed(seed):
training_args=args
)
trainer.train()
trainer.predict(dataset)
trainer.predict(dataset)

9 changes: 0 additions & 9 deletions python/fate/ml/nn/trainer/trainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,6 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
self._client_send_parameters(state, args, train_dataloader)


class FatePrinterCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero:
_ = logs.pop("total_flos", None)
logger.info(str(logs))


class CallbackWrapper(TrainerCallback):
def __init__(self, ctx: Context, wrapped_trainer: "HomoTrainerMixin"):
self.ctx = ctx
Expand Down Expand Up @@ -837,7 +830,6 @@ def _add_fate_callback(self, callback_handler):
continue
else:
new_callback_list.append(i)
new_callback_list.append(FatePrinterCallback())
callback_handler.callbacks = new_callback_list
callback_handler.callbacks.append(WrappedFedCallback(self.ctx, self))
callback_handler.callbacks.append(
Expand Down Expand Up @@ -935,7 +927,6 @@ def __init__(self,
continue
else:
new_callback_list.append(i)
new_callback_list.append(FatePrinterCallback())
self.callback_handler.callbacks = new_callback_list


Expand Down

0 comments on commit fa2704a

Please sign in to comment.