diff --git a/python/fate/ml/glm/coordinated_lr/arbiter.py b/python/fate/ml/glm/coordinated_lr/arbiter.py
index c5480e71b8..1f211b3787 100644
--- a/python/fate/ml/glm/coordinated_lr/arbiter.py
+++ b/python/fate/ml/glm/coordinated_lr/arbiter.py
@@ -68,7 +68,7 @@ def fit(self, ctx: Context) -> None:
         if label_count > 2:
             self.ovr = True
             self.estimator = {}
-            for i, class_ctx in ctx.range(range(label_count)):
+            for i, class_ctx in ctx.ctxs_range(label_count):
                 optimizer = copy.deepcopy(self.optimizer)
                 lr_scheduler = copy.deepcopy(self.lr_scheduler)
                 single_estimator = CoordinatedLREstimatorArbiter(max_iter=self.max_iter,
@@ -147,11 +147,14 @@ def fit_single_model(self, ctx, decryptor):
         else:
             optimizer_ready = True
             self.start_iter = self.end_iter + 1
 
-        for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):
+        # temp code start
+        # for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
+        for i, iter_ctx in ctx.ctxs_range(self.max_iter):
+        # temp code ends
             iter_loss = None
             iter_g = None
             self.optimizer.set_iters(i)
-            for batch_ctx, _ in iter_ctx.iter(batch_loader):
+            for batch_ctx, _ in iter_ctx.ctxs_zip(batch_loader):
                 g_guest_enc = batch_ctx.guest.get("g_enc")
                 g_guest = decryptor.decrypt(g_guest_enc)
diff --git a/python/fate/ml/glm/coordinated_lr/guest.py b/python/fate/ml/glm/coordinated_lr/guest.py
index 0a9dbe4713..e95d0f30ab 100644
--- a/python/fate/ml/glm/coordinated_lr/guest.py
+++ b/python/fate/ml/glm/coordinated_lr/guest.py
@@ -66,7 +66,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
             label_count = train_data_binarized_label.shape[1]
             ctx.arbiter.put("label_count", label_count)
             ctx.hosts.put("label_count", label_count)
-            self.labels = [label_name.split('_')[1] for label_name in label_count.columns]
+            self.labels = [label_name.split('_')[1] for label_name in train_data_binarized_label.columns]
             with_weight = train_data.weight is not None
             """
             # temp code start
@@ -219,16 +219,16 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=Fals
         """if train_data.weight:
             self.with_weight = True"""
 
-        for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
-        # temp code start
-        # for i, iter_ctx in ctx.range(self.max_iter):
+        # for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
+        # temp code start
+        for i, iter_ctx in ctx.ctxs_range(self.max_iter):
         # temp code end
             logger.info(f"start iter {i}")
             j = 0
             self.optimizer.set_iters(i)
             logger.info(f"self.optimizer set iters{i}")
             # todo: if self.with_weight: include weight in batch result
-            for batch_ctx, (X, Y, weight) in iter_ctx.iter(batch_loader):
+            for batch_ctx, (X, Y) in iter_ctx.ctxs_zip(batch_loader):
                 # temp code start
                 # for batch_ctx, (X, Y) in iter_ctx.iter(batch_loader):
                 # for batch_ctx, X, Y in [(iter_ctx, train_data, train_data.label)]:
@@ -266,8 +266,9 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=Fals
                 g = self.optimizer.add_regular_to_grad(X.T @ d, w, self.init_param.fit_intercept)
                 batch_ctx.arbiter.put("g_enc", g)
                 g = batch_ctx.arbiter.get("g")
-                # @todo: optimizer.step()?
+                # self.optimizer.step(g)
                 w = self.optimizer.update_weights(w, g, self.init_param.fit_intercept, self.lr_scheduler.lr)
+                logger.info(f"w={w}")
                 j += 1
 
             self.is_converged = ctx.arbiter("converge_flag").get()
diff --git a/python/fate/ml/glm/coordinated_lr/host.py b/python/fate/ml/glm/coordinated_lr/host.py
index b12784eafc..fdb45316df 100644
--- a/python/fate/ml/glm/coordinated_lr/host.py
+++ b/python/fate/ml/glm/coordinated_lr/host.py
@@ -159,16 +159,16 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No
         batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")
         if self.end_iter >= 0:
             self.start_iter = self.end_iter + 1
-        for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
-        # temp code start
-        # for i, iter_ctx in ctx.range(self.max_iter):
+        # for i, iter_ctx in ctx.ctxs_range(self.start_iter, self.max_iter):
+        # temp code start
+        for i, iter_ctx in ctx.ctxs_range(self.max_iter):
         # temp code end
             logger.info(f"start iter {i}")
             j = 0
             self.optimizer.set_iters(i)
             logger.info(f"self.optimizer set iters{i}")
             # temp code start
-            for batch_ctx, X in iter_ctx.iter(batch_loader):
+            for batch_ctx, X in iter_ctx.ctxs_zip(batch_loader):
                 # for batch_ctx, X in zip([iter_ctx], [train_data]):
                 # temp code end
                 # h = X.shape[0]
diff --git a/python/fate/ml/glm/hetero_linr/guest.py b/python/fate/ml/glm/hetero_linr/guest.py
index db9b3a75c7..f832a5dc2a 100644
--- a/python/fate/ml/glm/hetero_linr/guest.py
+++ b/python/fate/ml/glm/hetero_linr/guest.py
@@ -119,7 +119,7 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False):
         batch_loader = dataframe.DataLoader(
             train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True,
             # with_weight=True
-        )  # @todo: include batch weight
+        )
         if self.end_iter >= 0:
             self.start_iter = self.end_iter + 1
         for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):
@@ -127,7 +127,7 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False):
             j = 0
             self.optimizer.set_iters(i)
             # for batch_ctx, (X, Y, weight) in iter_ctx.iter(batch_loader):
-            for batch_ctx, X, Y in iter_ctx.iter(batch_loader):
+            for batch_ctx, X, Y in iter_ctx.ctxs_zip(batch_loader):
                 h = X.shape[0]
                 Xw = torch.matmul(X, w)
                 d = Xw - Y
diff --git a/python/fate/ml/glm/hetero_linr/host.py b/python/fate/ml/glm/hetero_linr/host.py
index e889caf3ba..f7b185e2cf 100644
--- a/python/fate/ml/glm/hetero_linr/host.py
+++ b/python/fate/ml/glm/hetero_linr/host.py
@@ -107,12 +107,12 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) ->
             self.start_iter = self.end_iter + 1
         """for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):"""
         # temp code start
-        for i, iter_ctx in ctx.range(self.max_iter):
+        for i, iter_ctx in ctx.ctxs_range(self.max_iter):
         # temp code end
             logger.info(f"start iter {i}")
             j = 0
             self.optimizer.set_iters(i)
-            for batch_ctx, X in iter_ctx.iter(batch_loader):
+            for batch_ctx, X in iter_ctx.ctxs_zip(batch_loader):
                 # h = X.shape[0]
                 logger.info(f"start batch {j}")
                 Xw_h = torch.matmul(X, w)
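
For reference, the loop shape that this patch converges on across all parties is sketched below with stand-in objects. The sketch only assumes what the hunks above show: ctx.ctxs_range(n) yields (index, sub_context) pairs and iter_ctx.ctxs_zip(loader) yields (sub_context, batch) pairs. The _Ctx class and the train() helper are hypothetical placeholders, not FATE APIs.

# Minimal, self-contained sketch of the iteration pattern adopted in this patch.
# _Ctx is a hypothetical stand-in for fate's Context: only the two helpers the
# diff relies on (ctxs_range, ctxs_zip) are mimicked, with assumed signatures.
from typing import Iterable, Iterator, Tuple


class _Ctx:
    def __init__(self, name: str = "root") -> None:
        self.name = name

    def ctxs_range(self, n: int) -> Iterator[Tuple[int, "_Ctx"]]:
        # assumed behaviour: one numbered sub-context per training iteration
        for i in range(n):
            yield i, _Ctx(f"{self.name}.iter{i}")

    def ctxs_zip(self, batches: Iterable) -> Iterator[Tuple["_Ctx", object]]:
        # assumed behaviour: one sub-context paired with each batch from the loader
        for j, batch in enumerate(batches):
            yield _Ctx(f"{self.name}.batch{j}"), batch


def train(ctx: _Ctx, batch_loader: Iterable, max_iter: int) -> None:
    # mirrors the guest/host loops after the change: an outer ctxs_range loop over
    # iterations, an inner ctxs_zip loop pairing sub-contexts with data batches
    for i, iter_ctx in ctx.ctxs_range(max_iter):
        for batch_ctx, (X, Y) in iter_ctx.ctxs_zip(batch_loader):
            print(f"iter {i}, {batch_ctx.name}: X={X}, Y={Y}")


if __name__ == "__main__":
    train(_Ctx(), batch_loader=[([1.0, 2.0], 0), ([3.0, 4.0], 1)], max_iter=2)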