fix guest gradient computation (#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed Jul 12, 2023
1 parent e803cdc commit 37a0809
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/fate/ml/glm/hetero/coordinated_linr/guest.py
@@ -155,7 +155,7 @@ def fit_model(self, ctx, train_data, validate_data=None):
 batch_ctx.arbiter.put(loss=loss)

 # gradient
-g = 1 / h * X.T @ d
+g = 1 / h * torch.matmul(X.T, d)
 g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
 batch_ctx.arbiter.put("g_enc", g)
 g = batch_ctx.arbiter.get("g")
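Note: for ordinary dense tensors the replaced expression and its torch.matmul form compute the same gradient; the difference presumably only matters for the tensor types FATE passes through this code path. A minimal sketch with hypothetical shapes, using plain torch tensors:

import torch

h, n_features = 32, 4            # hypothetical batch size and feature count
X = torch.randn(h, n_features)   # guest-side feature batch
d = torch.randn(h, 1)            # residual-like term exchanged in the protocol

g_old = 1 / h * X.T @ d                  # expression before this commit
g_new = 1 / h * torch.matmul(X.T, d)     # expression after this commit

assert torch.allclose(g_old, g_new)      # same gradient, shape (n_features, 1)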
3 changes: 2 additions & 1 deletion python/fate/ml/glm/hetero/coordinated_lr/guest.py
@@ -218,6 +218,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
 loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
 loss += 0.25 / h * torch.matmul(Xw.T, Xw_h)
 if weight:
+    logger.info(f"weight: {weight.tolist()}")
     d = d * weight
 batch_ctx.hosts.put(d=d)

@@ -232,7 +233,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
 batch_ctx.arbiter.put(loss=loss)

 # gradient
-g = 1 / h * X.T @ d
+g = 1 / h * torch.matmul(X.T, d)
 g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
 batch_ctx.arbiter.put("g_enc", g)
 g = batch_ctx.arbiter.get("g")
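The remaining addition only logs the sample weights before they scale d. A minimal sketch of that weighted gradient step with hypothetical values (plain torch tensors and a locally configured logger, not FATE's context objects):

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

h = 4
X = torch.randn(h, 3)                                  # hypothetical feature batch
d = torch.randn(h, 1)                                  # per-sample residual-like term
weight = torch.tensor([[1.0], [2.0], [1.0], [0.5]])    # hypothetical sample weights

logger.info(f"weight: {weight.tolist()}")              # what the added line records
d = d * weight                                         # weights applied before the gradient
g = 1 / h * torch.matmul(X.T, d)                       # weighted gradient, shape (3, 1)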
