Commit

edit glm(#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed Jul 3, 2023
1 parent 8214dde commit 6b32e36
Showing 5 changed files with 21 additions and 17 deletions.
9 changes: 6 additions & 3 deletions python/fate/ml/glm/coordinated_lr/arbiter.py
@@ -68,7 +68,7 @@ def fit(self, ctx: Context) -> None:
if label_count > 2:
self.ovr = True
self.estimator = {}
for i, class_ctx in ctx.range(range(label_count)):
for i, class_ctx in ctx.ctxs_range(label_count):
optimizer = copy.deepcopy(self.optimizer)
lr_scheduler = copy.deepcopy(self.lr_scheduler)
single_estimator = CoordinatedLREstimatorArbiter(max_iter=self.max_iter,
@@ -147,11 +147,14 @@ def fit_single_model(self, ctx, decryptor):
else:
optimizer_ready = True
self.start_iter = self.end_iter + 1
for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):
# temp code start
# for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
for i, iter_ctx in ctx.ctxs_range(self.max_iter):
# temp code ends
iter_loss = None
iter_g = None
self.optimizer.set_iters(i)
for batch_ctx, _ in iter_ctx.iter(batch_loader):
for batch_ctx, _ in iter_ctx.ctxs_zip(batch_loader):

g_guest_enc = batch_ctx.guest.get("g_enc")
g_guest = decryptor.decrypt(g_guest_enc)
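The recurring change in this commit is the switch from ctx.range(...) / iter_ctx.iter(...) to ctx.ctxs_range(...) / iter_ctx.ctxs_zip(...). As a minimal sketch of the calling convention the new code appears to rely on (stub classes below, not FATE's real Context or DataLoader; the exact semantics are assumed from the diff):

# Stub illustration only: FATE's real Context/DataLoader are not used here.
# Assumed from the diff: ctxs_range(n) yields (index, sub-context) pairs and
# ctxs_zip(iterable) pairs every item with its own sub-context.

class StubContext:
    def __init__(self, name="root"):
        self.name = name

    def ctxs_range(self, n):
        # one sub-context per training iteration
        for i in range(n):
            yield i, StubContext(f"{self.name}.iter_{i}")

    def ctxs_zip(self, iterable):
        # one sub-context per batch produced by the loader
        for j, item in enumerate(iterable):
            yield StubContext(f"{self.name}.batch_{j}"), item


if __name__ == "__main__":
    ctx = StubContext()
    batch_loader = [[1.0, 2.0], [3.0, 4.0]]  # stands in for a DataLoader
    for i, iter_ctx in ctx.ctxs_range(2):
        for batch_ctx, batch in iter_ctx.ctxs_zip(batch_loader):
            print(i, batch_ctx.name, batch)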
13 changes: 7 additions & 6 deletions python/fate/ml/glm/coordinated_lr/guest.py
@@ -66,7 +66,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
label_count = train_data_binarized_label.shape[1]
ctx.arbiter.put("label_count", label_count)
ctx.hosts.put("label_count", label_count)
self.labels = [label_name.split('_')[1] for label_name in label_count.columns]
self.labels = [label_name.split('_')[1] for label_name in train_data_binarized_label.columns]
with_weight = train_data.weight is not None
"""
# temp code start
@@ -219,16 +219,16 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=Fals
"""if train_data.weight:
self.with_weight = True"""

for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
# temp code start
# for i, iter_ctx in ctx.range(self.max_iter):
# for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
# temp code start
for i, iter_ctx in ctx.ctxs_range(self.max_iter):
# temp code end
logger.info(f"start iter {i}")
j = 0
self.optimizer.set_iters(i)
logger.info(f"self.optimizer set iters{i}")
# todo: if self.with_weight: include weight in batch result
for batch_ctx, (X, Y, weight) in iter_ctx.iter(batch_loader):
for batch_ctx, (X, Y) in iter_ctx.ctxs_zip(batch_loader):
# temp code start
# for batch_ctx, (X, Y) in iter_ctx.iter(batch_loader):
# for batch_ctx, X, Y in [(iter_ctx, train_data, train_data.label)]:
@@ -266,8 +266,9 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=Fals
g = self.optimizer.add_regular_to_grad(X.T @ d, w, self.init_param.fit_intercept)
batch_ctx.arbiter.put("g_enc", g)
g = batch_ctx.arbiter.get("g")
# @todo: optimizer.step()?
# self.optimizer.step(g)
w = self.optimizer.update_weights(w, g, self.init_param.fit_intercept, self.lr_scheduler.lr)

logger.info(f"w={w}")
j += 1
self.is_converged = ctx.arbiter("converge_flag").get()
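The guest loop above also replaces the commented-out self.optimizer.step(g) with self.optimizer.update_weights(w, g, self.init_param.fit_intercept, self.lr_scheduler.lr), and the batch unpacking drops the weight term for now (flagged by the "# todo" above). A plausible plain gradient-descent reading of that update, for illustration only; FATE's actual optimizer may treat regularization and the intercept differently:

# Illustration only: a plain gradient-descent step matching the shape of the
# new update_weights(...) call; not FATE's actual optimizer implementation.
import torch


def update_weights(w, g, fit_intercept, lr):
    # assumed: w and g are column tensors of shape (d, 1); if fit_intercept is
    # True, the last row holds the intercept and is updated like any other weight
    return w - lr * g


w = torch.zeros(3, 1)
g = torch.tensor([[0.5], [-0.2], [0.1]])
w = update_weights(w, g, fit_intercept=True, lr=0.1)
print(w)  # one step of w <- w - lr * g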
8 changes: 4 additions & 4 deletions python/fate/ml/glm/coordinated_lr/host.py
@@ -159,16 +159,16 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No
batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")
if self.end_iter >= 0:
self.start_iter = self.end_iter + 1
for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
# temp code start
# for i, iter_ctx in ctx.range(self.max_iter):
# for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
# temp code start
for i, iter_ctx in ctx.ctxs_range(self.max_iter):
# temp code end
logger.info(f"start iter {i}")
j = 0
self.optimizer.set_iters(i)
logger.info(f"self.optimizer set iters{i}")
# temp code start
for batch_ctx, X in iter_ctx.iter(batch_loader):
for batch_ctx, X in iter_ctx.ctxs_zip(batch_loader):
# for batch_ctx, X in zip([iter_ctx], [train_data]):
# temp code end
# h = X.shape[0]
4 changes: 2 additions & 2 deletions python/fate/ml/glm/hetero_linr/guest.py
@@ -119,15 +119,15 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False):
batch_loader = dataframe.DataLoader(
train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True,
# with_weight=True
) # @todo: include batch weight
)
if self.end_iter >= 0:
self.start_iter = self.end_iter + 1
for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):
logger.info(f"start iter {i}")
j = 0
self.optimizer.set_iters(i)
# for batch_ctx, (X, Y, weight) in iter_ctx.iter(batch_loader):
for batch_ctx, X, Y in iter_ctx.iter(batch_loader):
for batch_ctx, X, Y in iter_ctx.ctxs_zip(batch_loader):
h = X.shape[0]
Xw = torch.matmul(X, w)
d = Xw - Y
4 changes: 2 additions & 2 deletions python/fate/ml/glm/hetero_linr/host.py
@@ -107,12 +107,12 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) ->
self.start_iter = self.end_iter + 1
"""for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):"""
# temp code start
for i, iter_ctx in ctx.range(self.max_iter):
for i, iter_ctx in ctx.ctxs_range(self.max_iter):
# temp code end
logger.info(f"start iter {i}")
j = 0
self.optimizer.set_iters(i)
for batch_ctx, X in iter_ctx.iter(batch_loader):
for batch_ctx, X in iter_ctx.ctxs_zip(batch_loader):
# h = X.shape[0]
logger.info(f"start batch {j}")
Xw_h = torch.matmul(X, w)
