Skip to content

Commit

Permalink
Merge pull request #5359 from FederatedAI/feature-2.0.0-rc-sshe-glm
Browse files Browse the repository at this point in the history
Fix SSHE loss reporting: weight each batch's revealed loss by its batch size and divide the accumulated epoch loss by the total number of training samples, so the logged loss is a correct per-sample average rather than a sum of per-batch means.
  • Loading branch information
mgqa34 authored Dec 21, 2023
2 parents 19dddee + c0f49ec commit 39f457d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions python/fate/ml/glm/hetero/sshe/sshe_linr.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
initialize_func = lambda x: self.w
if self.init_param.get("fit_intercept"):
train_data["intercept"] = 1.0
train_data_n = train_data.shape[0]
layer = SSHELinearRegressionLayer(
ctx,
in_features_a=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_a),
Expand Down Expand Up @@ -210,12 +211,13 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
loss = loss_fn(z, y)
if i % self.reveal_loss_freq == 0:
if epoch_loss is None:
epoch_loss = loss.get(dst=rank_b)
epoch_loss = loss.get(dst=rank_b) * h.shape[0]
else:
epoch_loss += loss.get(dst=rank_b)
epoch_loss += loss.get(dst=rank_b) * h.shape[0]
loss.backward()
optimizer.step()
if epoch_loss is not None and ctx.is_on_guest:
epoch_loss = epoch_loss / train_data_n
epoch_ctx.metrics.log_loss("linr_loss", epoch_loss.tolist())
# if self.reveal_every_epoch:
# wa_p = wa.get_plain_text(dst=rank_a)
Expand Down
6 changes: 4 additions & 2 deletions python/fate/ml/glm/hetero/sshe/sshe_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
initialize_func = lambda x: self.w
if self.init_param.get("fit_intercept"):
train_data["intercept"] = 1.0
train_data_n = train_data.shape[0]
layer = SSHELogisticRegressionLayer(
ctx,
in_features_a=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_a),
Expand Down Expand Up @@ -314,12 +315,13 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
loss = loss_fn(z, y)
if i % self.reveal_loss_freq == 0:
if epoch_loss is None:
epoch_loss = loss.get(dst=rank_b)
epoch_loss = loss.get(dst=rank_b) * h.shape[0]
else:
epoch_loss += loss.get(dst=rank_b)
epoch_loss += loss.get(dst=rank_b) * h.shape[0]
loss.backward()
optimizer.step()
if epoch_loss is not None and ctx.is_on_guest:
epoch_loss = epoch_loss / train_data_n
epoch_ctx.metrics.log_loss("lr_loss", epoch_loss.tolist())
# if self.reveal_every_epoch:
# wa_p = wa.get_plain_text(dst=rank_a)
Expand Down

0 comments on commit 39f457d

Please sign in to comment.