Update Homo LR
Signed-off-by: cwj <talkingwallace@sohu.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
talkingwallace authored and sagewe committed Jul 21, 2023
1 parent 0a3b93c commit 7bd3c59
Showing 5 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/fate/ml/glm/homo/lr/client.py
@@ -424,7 +424,7 @@ def fit(self, ctx: Context, train_data: DataFrame,
self.trainer.set_local_mode()
self.trainer.train()

- logger.info('training finished')
+ logger.info('homo lr fit done')

def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:

21 changes: 18 additions & 3 deletions python/fate/ml/glm/homo/lr/test/test_fed_lr.py
@@ -50,11 +50,17 @@ def create_ctx(local):
dtype="object")

data = reader.to_frame(ctx, df)
+ client = HomoLRClient(
+ 50, 800, optimizer_param={
+ 'method': 'adam', 'penalty': 'l1', 'aplha': 0.1, 'optimizer_para': {
+ 'lr': 0.1}}, init_param={
+ 'method': 'random', 'fill_val': 1.0})

+ client.fit(ctx, data)

elif sys.argv[1] == "host":

- ctx = create_ctx(guest)
+ ctx = create_ctx(host)
df = pd.read_csv(
'../../../../../../../examples/data/breast_homo_host.csv')
df['sample_id'] = [i for i in range(len(df))]
@@ -66,6 +72,15 @@ def create_ctx(local):
dtype="object")

data = reader.to_frame(ctx, df)

+ client = HomoLRClient(
+ 50, 800, optimizer_param={
+ 'method': 'adam', 'penalty': 'l1', 'aplha': 0.1, 'optimizer_para': {
+ 'lr': 0.1}}, init_param={
+ 'method': 'random', 'fill_val': 1.0})

+ client.fit(ctx, data)
else:
ctx = create_ctx(arbiter)

ctx = create_ctx(arbiter)
server = HomoLRServer()
server.fit(ctx)
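
Taken together, the updated test now exercises all three parties: guest and host each load their own breast_homo CSV split, build a HomoLRClient with identical optimizer and init parameters, and call fit, while the arbiter branch runs HomoLRServer. Below is a condensed sketch of that flow; run_party is a hypothetical wrapper, and HomoLRClient, HomoLRServer, create_ctx, reader and the party constants are assumed to be the objects the real test_fed_lr.py imports and sets up earlier in the file.

```python
import pandas as pd

def run_party(role):
    # Hypothetical wrapper over the per-role branches in test_fed_lr.py.
    # HomoLRClient, HomoLRServer, create_ctx, reader, guest, host and arbiter
    # are assumed to come from the test's own imports and setup.
    if role in ("guest", "host"):
        ctx = create_ctx(guest if role == "guest" else host)
        # Each data party loads its own horizontal split of the breast data.
        df = pd.read_csv(f"examples/data/breast_homo_{role}.csv")  # shortened path
        df["sample_id"] = list(range(len(df)))
        data = reader.to_frame(ctx, df)
        client = HomoLRClient(
            50, 800,
            optimizer_param={"method": "adam", "penalty": "l1",
                             "aplha": 0.1,  # spelling kept as in the committed test
                             "optimizer_para": {"lr": 0.1}},
            init_param={"method": "random", "fill_val": 1.0})
        client.fit(ctx, data)
    else:
        # The arbiter side only coordinates aggregation.
        ctx = create_ctx(arbiter)
        server = HomoLRServer()
        server.fit(ctx)
```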
14 changes: 1 addition & 13 deletions python/fate/ml/nn/algo/homo/fedavg.py
@@ -24,19 +24,7 @@

@dataclass
class FedAVGArguments(FedArguments):

"""
The arguemnt for FedAVG algorithm, used in FedAVGClient and FedAVGServer.
Attributes:
weighted_aggregate: bool
Whether to use weighted aggregation or not.
secure_aggregate: bool
Whether to use secure aggregation or not.
"""

weighted_aggregate: bool = field(default=True)
secure_aggregate: bool = field(default=False)
pass


class FedAVGCLient(FedTrainerClient):
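
With those two fields removed, FedAVGArguments no longer carries any options of its own; the aggregation behaviour is selected through the shared FedArguments instead, which is what the updated test_trainer.py below passes. A minimal sketch of that pattern; the import path is a guess, and the keyword values are the ones used in the test.

```python
# Import path is an assumption; the tests simply import FedArguments alongside
# FedAVGCLient.
from fate.ml.nn.trainer.trainer_base import FedArguments

# Aggregation options now live on FedArguments rather than on FedAVGArguments:
fed_args = FedArguments(aggregate_strategy="epochs",  # aggregate once per epoch
                        aggregate_freq=1,
                        aggregator="secure_aggregate")
```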
4 changes: 2 additions & 2 deletions python/fate/ml/nn/trainer/test/test_trainer.py
@@ -45,14 +45,14 @@ def create_ctx(local):

if sys.argv[1] == "guest":
ctx = create_ctx(guest)
- fed_args = FedArguments(aggregate_strategy='epochs', aggregate_freq=1, aggregator='secure_aggrefgate')
+ fed_args = FedArguments(aggregate_strategy='epochs', aggregate_freq=1, aggregator='secure_aggregate')
args = TrainingArguments(num_train_epochs=5, per_device_train_batch_size=16, logging_strategy='steps', logging_steps=5)
trainer = FedAVGCLient(ctx=ctx, model=model, fed_args=fed_args, training_args=args, loss_fn=t.nn.BCELoss(), optimizer=t.optim.SGD(model.parameters(), lr=0.01), train_set=ds)
trainer.train()

elif sys.argv[1] == "host":
ctx = create_ctx(host)
- fed_args = FedArguments(aggregate_strategy='epochs', aggregate_freq=1, aggregator='plaintext')
+ fed_args = FedArguments(aggregate_strategy='epochs', aggregate_freq=1, aggregator='secure_aggregate')
args = TrainingArguments(num_train_epochs=5, per_device_train_batch_size=16)
trainer = FedAVGCLient(ctx=ctx, model=model, fed_args=fed_args, training_args=args, loss_fn=t.nn.BCELoss(), optimizer=t.optim.SGD(model.parameters(), lr=0.01), train_set=ds)
trainer.train()
1 change: 0 additions & 1 deletion python/fate/ml/utils/_optimizer.py
@@ -346,7 +346,6 @@ def optimizer_factory(model_parameter, optimizer_type, optim_params):

def lr_scheduler_factory(optimizer, method, scheduler_param):
scheduler_method = method

if scheduler_method == 'constant':
return torch.optim.lr_scheduler.ConstantLR(
optimizer, **scheduler_param)
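
For context, lr_scheduler_factory dispatches on the method string and forwards scheduler_param as keyword arguments to the matching torch scheduler, as the 'constant' branch above shows. A small usage sketch under that reading; the module path is inferred from this file's location and the parameter values are placeholders.

```python
import torch

# Module path inferred from python/fate/ml/utils/_optimizer.py in this commit.
from fate.ml.utils._optimizer import lr_scheduler_factory

# Throwaway model/optimizer so there is something to schedule.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 'constant' resolves to torch.optim.lr_scheduler.ConstantLR(optimizer, **scheduler_param),
# which scales the learning rate by `factor` for the first `total_iters` steps.
scheduler = lr_scheduler_factory(optimizer, "constant",
                                 {"factor": 0.5, "total_iters": 10})

for _ in range(3):
    optimizer.step()
    scheduler.step()
```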
