Commit

Merge pull request #5365 from FederatedAI/feature-2.0.0-rc-sshe-glm
Feature 2.0.0 rc sshe glm
mgqa34 authored Dec 21, 2023
2 parents 80386ad + a25bb8d commit d7a4cba
Showing 8 changed files with 30 additions and 22 deletions.
6 changes: 3 additions & 3 deletions examples/pipeline/sshe_linr/test_linr_cv.py
@@ -37,9 +37,9 @@ def main(config="../config.yaml", namespace=""):
     reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
     psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
     linr_0 = SSHELinR("linr_0",
-                      epochs=10,
-                      batch_size=None,
-                      learning_rate=0.05,
+                      epochs=3,
+                      batch_size=100,
+                      learning_rate=0.15,
                       init_param={"fit_intercept": True},
                       cv_data=psi_0.outputs["output_data"],
                       cv_param={"n_splits": 3},
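A note on the retuned values above: batch_size=None in these examples corresponds to full-batch training, one optimizer step per epoch, while a concrete batch_size splits each epoch into about ceil(n_samples / batch_size) mini-batch updates, so epochs=3 with batch_size=100 can still perform far more optimizer steps than epochs=10 full-batch. A quick, self-contained sketch of that arithmetic (the sample count is hypothetical, used only for illustration):

import math

def optimizer_steps(epochs, n_samples, batch_size):
    # batch_size=None is treated as one full-batch step per epoch
    steps_per_epoch = 1 if batch_size is None else math.ceil(n_samples / batch_size)
    return epochs * steps_per_epoch

n = 10_000  # hypothetical dataset size
print(optimizer_steps(epochs=10, n_samples=n, batch_size=None))  # 10 steps, old settings
print(optimizer_steps(epochs=3, n_samples=n, batch_size=100))    # 300 steps, new settings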
2 changes: 1 addition & 1 deletion examples/pipeline/sshe_linr/test_linr_validate.py
@@ -49,7 +49,7 @@ def main(config="../config.yaml", namespace=""):
                       reveal_every_epoch=False,
                       early_stop="diff",
                       reveal_loss_freq=1,
-                      learning_rate=0.1)
+                      learning_rate=0.15)
     evaluation_0 = Evaluation("evaluation_0",
                               runtime_parties=dict(guest=guest),
                               default_eval_setting="regression",
12 changes: 6 additions & 6 deletions examples/pipeline/sshe_linr/test_linr_warm_start.py
@@ -39,29 +39,29 @@ def main(config="../config.yaml", namespace=""):
     psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
     linr_0 = SSHELinR("linr_0",
                       epochs=4,
-                      batch_size=None,
+                      batch_size=100,
                       init_param={"fit_intercept": True, "method": "zeros"},
                       train_data=psi_0.outputs["output_data"],
-                      learning_rate=0.05,
+                      learning_rate=0.15,
                       reveal_every_epoch=False,
                       early_stop="diff",
                       reveal_loss_freq=1,
                       )
     linr_1 = SSHELinR("linr_1", train_data=psi_0.outputs["output_data"],
                       warm_start_model=linr_0.outputs["output_model"],
                       epochs=2,
-                      batch_size=None,
-                      learning_rate=0.05,
+                      batch_size=100,
+                      learning_rate=0.15,
                       reveal_every_epoch=True,
                       early_stop="diff",
                       reveal_loss_freq=1,
                       )

     linr_2 = SSHELinR("linr_2", epochs=6,
-                      batch_size=None,
+                      batch_size=100,
                       init_param={"fit_intercept": True, "method": "zeros"},
                       train_data=psi_0.outputs["output_data"],
-                      learning_rate=0.05,
+                      learning_rate=0.15,
                       reveal_every_epoch=False,
                       early_stop="diff",
                       reveal_loss_freq=1,
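The warm-start examples here and in test_lr_warm_start.py below share one pattern: a first model trained for 4 epochs, a second model that continues it for 2 more epochs via warm_start_model, and a third model trained from scratch for 6 epochs, presumably so the warm-started result can be checked against a single uninterrupted run of the same total length. A minimal sketch of that epoch bookkeeping with a hypothetical estimator (not the FATE API):

class ToyEstimator:
    # Hypothetical stand-in for an estimator that supports warm starting.
    def __init__(self, epochs, warm_start_model=None):
        self.epochs = epochs
        self.trained_epochs = warm_start_model.trained_epochs if warm_start_model else 0

    def fit(self):
        self.trained_epochs += self.epochs
        return self

model_0 = ToyEstimator(epochs=4).fit()                            # like linr_0 / lr_0
model_1 = ToyEstimator(epochs=2, warm_start_model=model_0).fit()  # like linr_1 / lr_1
model_2 = ToyEstimator(epochs=6).fit()                            # like linr_2 / lr_2
assert model_1.trained_epochs == model_2.trained_epochs == 6      # same total training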
2 changes: 1 addition & 1 deletion examples/pipeline/sshe_lr/test_lr_cv.py
@@ -45,7 +45,7 @@ def main(config="../config.yaml", namespace=""):
     lr_0 = SSHELR("lr_0",
                   learning_rate=0.15,
                   epochs=2,
-                  batch_size=None,
+                  batch_size=300,
                   init_param={"fit_intercept": True},
                   cv_data=psi_0.outputs["output_data"],
                   cv_param={"n_splits": 3},
2 changes: 1 addition & 1 deletion examples/pipeline/sshe_lr/test_lr_multi_class.py
@@ -47,7 +47,7 @@ def main(config="../config.yaml", namespace=""):
     lr_0 = SSHELR("lr_0",
                   learning_rate=0.15,
                   epochs=10,
-                  batch_size=None,
+                  batch_size=300,
                   reveal_every_epoch=False,
                   early_stop="diff",
                   reveal_loss_freq=1,
12 changes: 6 additions & 6 deletions examples/pipeline/sshe_lr/test_lr_warm_start.py
@@ -45,8 +45,8 @@ def main(config="../config.yaml", namespace=""):
     psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
     lr_0 = SSHELR("lr_0",
                   epochs=4,
-                  batch_size=None,
-                  learning_rate=0.05,
+                  batch_size=300,
+                  learning_rate=0.15,
                   init_param={"fit_intercept": True, "method": "zeros"},
                   train_data=psi_0.outputs["output_data"],
                   reveal_every_epoch=False,
@@ -56,16 +56,16 @@ def main(config="../config.yaml", namespace=""):
     lr_1 = SSHELR("lr_1", train_data=psi_0.outputs["output_data"],
                   warm_start_model=lr_0.outputs["output_model"],
                   epochs=2,
-                  batch_size=None,
-                  learning_rate=0.05,
+                  batch_size=300,
+                  learning_rate=0.15,
                   reveal_every_epoch=False,
                   early_stop="diff",
                   reveal_loss_freq=1,
                   )

     lr_2 = SSHELR("lr_2", epochs=6,
-                  batch_size=None,
-                  learning_rate=0.05,
+                  batch_size=300,
+                  learning_rate=0.15,
                   init_param={"fit_intercept": True, "method": "zeros"},
                   train_data=psi_0.outputs["output_data"],
                   reveal_every_epoch=False,
8 changes: 6 additions & 2 deletions python/fate/ml/glm/hetero/sshe/sshe_linr.py
@@ -211,9 +211,13 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
                 loss = loss_fn(z, y)
                 if i % self.reveal_loss_freq == 0:
                     if epoch_loss is None:
-                        epoch_loss = loss.get(dst=rank_b) * h.shape[0]
+                        epoch_loss = loss.get(dst=rank_b)
+                        if ctx.is_on_guest:
+                            epoch_loss = epoch_loss * h.shape[0]
                     else:
-                        epoch_loss += loss.get(dst=rank_b) * h.shape[0]
+                        batch_loss = loss.get(dst=rank_b)
+                        if ctx.is_on_guest:
+                            epoch_loss += batch_loss * h.shape[0]
                 loss.backward()
                 optimizer.step()
             if epoch_loss is not None and ctx.is_on_guest:
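The sshe_linr.py change above makes the loss bookkeeping explicit about which party holds the revealed value: loss.get(dst=rank_b) reveals the secret-shared loss to one destination party, which the surrounding ctx.is_on_guest guard indicates is the guest, so only the guest now scales the batch loss by h.shape[0] and accumulates it into epoch_loss, matching the existing guest-only check before the loss is reported. A minimal, self-contained sketch of that pattern with mock stand-ins (MockSharedLoss and accumulate_epoch_loss are illustrative helpers, not FATE classes):

class MockSharedLoss:
    # Stand-in for a secret-shared loss: only the destination party sees a value.
    def __init__(self, value, is_dst):
        self._value = value if is_dst else None

    def get(self, dst=None):
        return self._value  # None on every party other than dst

def accumulate_epoch_loss(batch_losses, batch_size, is_guest):
    epoch_loss = None
    for loss in batch_losses:
        if epoch_loss is None:
            epoch_loss = loss.get()
            if is_guest:                      # only the guest holds a real value
                epoch_loss = epoch_loss * batch_size
        else:
            batch_loss = loss.get()
            if is_guest:
                epoch_loss += batch_loss * batch_size
    return epoch_loss

# Guest accumulates the scaled losses; the host ends up with None and skips reporting.
print(accumulate_epoch_loss([MockSharedLoss(0.5, True), MockSharedLoss(0.4, True)], 100, True))    # 90.0
print(accumulate_epoch_loss([MockSharedLoss(0.5, False), MockSharedLoss(0.4, False)], 100, False)) # None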
8 changes: 6 additions & 2 deletions python/fate/ml/glm/hetero/sshe/sshe_lr.py
@@ -315,9 +315,13 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
                 loss = loss_fn(z, y)
                 if i % self.reveal_loss_freq == 0:
                     if epoch_loss is None:
-                        epoch_loss = loss.get(dst=rank_b) * h.shape[0]
+                        epoch_loss = loss.get(dst=rank_b)
+                        if epoch_loss:
+                            epoch_loss = epoch_loss * h.shape[0]
                     else:
-                        epoch_loss += loss.get(dst=rank_b) * h.shape[0]
+                        batch_loss = loss.get(dst=rank_b)
+                        if batch_loss:
+                            epoch_loss += batch_loss * h.shape[0]
                 loss.backward()
                 optimizer.step()
             if epoch_loss is not None and ctx.is_on_guest:
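The sshe_lr.py version applies the same fix but guards on the truthiness of the revealed value (if epoch_loss: / if batch_loss:) rather than on ctx.is_on_guest; under the convention that the non-destination party gets nothing back from loss.get(dst=rank_b), both guards skip the scaling on the host. A tiny sketch of that equivalence under that assumption (mock values only):

def guarded_scale(revealed, batch_size, is_guest):
    # Two guard styles for scaling a revealed loss; both leave the host untouched.
    by_role = revealed * batch_size if is_guest else None   # sshe_linr.py style
    by_value = revealed * batch_size if revealed else None  # sshe_lr.py style
    return by_role, by_value

print(guarded_scale(0.42, 300, is_guest=True))   # (126.0, 126.0) on the guest
print(guarded_scale(None, 300, is_guest=False))  # (None, None) on the host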
