diff --git a/python/fate/ml/glm/hetero/sshe/sshe_linr.py b/python/fate/ml/glm/hetero/sshe/sshe_linr.py
index 11193d541b..d8cdf8753c 100644
--- a/python/fate/ml/glm/hetero/sshe/sshe_linr.py
+++ b/python/fate/ml/glm/hetero/sshe/sshe_linr.py
@@ -168,6 +168,7 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
         initialize_func = lambda x: self.w
         if self.init_param.get("fit_intercept"):
             train_data["intercept"] = 1.0
+        train_data_n = train_data.shape[0]
         layer = SSHELinearRegressionLayer(
             ctx,
             in_features_a=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_a),
@@ -210,12 +211,13 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
                 loss = loss_fn(z, y)
                 if i % self.reveal_loss_freq == 0:
                     if epoch_loss is None:
-                        epoch_loss = loss.get(dst=rank_b)
+                        epoch_loss = loss.get(dst=rank_b) * h.shape[0]
                     else:
-                        epoch_loss += loss.get(dst=rank_b)
+                        epoch_loss += loss.get(dst=rank_b) * h.shape[0]
                 loss.backward()
                 optimizer.step()
             if epoch_loss is not None and ctx.is_on_guest:
+                epoch_loss = epoch_loss / train_data_n
                 epoch_ctx.metrics.log_loss("linr_loss", epoch_loss.tolist())
             # if self.reveal_every_epoch:
             #     wa_p = wa.get_plain_text(dst=rank_a)
diff --git a/python/fate/ml/glm/hetero/sshe/sshe_lr.py b/python/fate/ml/glm/hetero/sshe/sshe_lr.py
index 0de7e172e6..bd362e8838 100644
--- a/python/fate/ml/glm/hetero/sshe/sshe_lr.py
+++ b/python/fate/ml/glm/hetero/sshe/sshe_lr.py
@@ -272,6 +272,7 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
         initialize_func = lambda x: self.w
         if self.init_param.get("fit_intercept"):
             train_data["intercept"] = 1.0
+        train_data_n = train_data.shape[0]
         layer = SSHELogisticRegressionLayer(
             ctx,
             in_features_a=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_a),
@@ -314,12 +315,13 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
                 loss = loss_fn(z, y)
                 if i % self.reveal_loss_freq == 0:
                     if epoch_loss is None:
-                        epoch_loss = loss.get(dst=rank_b)
+                        epoch_loss = loss.get(dst=rank_b) * h.shape[0]
                     else:
-                        epoch_loss += loss.get(dst=rank_b)
+                        epoch_loss += loss.get(dst=rank_b) * h.shape[0]
                 loss.backward()
                 optimizer.step()
             if epoch_loss is not None and ctx.is_on_guest:
+                epoch_loss = epoch_loss / train_data_n
                 epoch_ctx.metrics.log_loss("lr_loss", epoch_loss.tolist())
             # if self.reveal_every_epoch:
             #     wa_p = wa.get_plain_text(dst=rank_a)
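
Both hunks apply the same averaging fix: `loss_fn` returns a per-batch *mean*, so summing those means over an epoch over-weights small trailing batches and scales with the number of batches rather than the dataset size. Weighting each revealed batch loss by its batch size (`h.shape[0]`) and dividing the accumulated total by `train_data_n` recovers the true per-sample epoch loss. A minimal sketch of the arithmetic in plain NumPy (not the FATE MPC API; the batch sizes and loss values below are illustrative only):

```python
import numpy as np

losses = np.array([0.90, 0.80, 0.40])  # per-batch mean losses, as loss_fn returns them
sizes = np.array([100, 100, 10])       # uneven batch sizes; the last batch is short

# Old behavior: a plain sum of means (2.10) depends on the batch count
# and lets the 10-sample batch count as much as a 100-sample batch.
naive = losses.sum()

# New behavior: weight each mean by its batch size, then divide by the
# total sample count, matching the epoch_loss / train_data_n step.
weighted = (losses * sizes).sum() / sizes.sum()  # ~0.829

# Sanity check: this equals the mean loss over all 210 individual samples.
per_sample = np.concatenate([np.full(n, l) for l, n in zip(losses, sizes)])
assert np.isclose(weighted, per_sample.mean())
```

A side effect worth noting: the logged `linr_loss` / `lr_loss` values now sit on a per-sample scale, so loss curves from runs with different `batch_size` settings become directly comparable.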