edit hetero lr for test (#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed May 4, 2023
1 parent 1894f72 commit 693ec96
Showing 3 changed files with 27 additions and 17 deletions.
16 changes: 6 additions & 10 deletions python/fate/components/components/hetero_lr.py
@@ -124,13 +124,11 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,
         module = HeteroLrModuleGuest(max_iter=max_iter, batch_size=batch_size,
                                      optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
                                      init_param=init_param, threshold=threshold)
-        train_data = sub_ctx.reader(train_data).read_dataframe()
+        train_data = sub_ctx.reader(train_data).read_dataframe().data
 
         if validate_data is not None:
-            validate_data = sub_ctx.reader(validate_data).read_dataframe()
-            # temp code start
-            train_data = train_data.data
-            # temp code end
+            validate_data = sub_ctx.reader(validate_data).read_dataframe().data
+
         module.fit(sub_ctx, train_data, validate_data)
         model = module.get_model()
         with output_model as model_writer:
@@ -150,13 +148,11 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model,
         module = HeteroLrModuleHost(max_iter=max_iter, batch_size=batch_size,
                                     optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
                                     init_param=init_param)
-        train_data = sub_ctx.reader(train_data).read_dataframe()
+        train_data = sub_ctx.reader(train_data).read_dataframe().data
 
         if validate_data is not None:
-            validate_data = sub_ctx.reader(validate_data).read_dataframe()
-            # temp code start
-            train_data = train_data.data
-            # temp code end
+            validate_data = sub_ctx.reader(validate_data).read_dataframe().data
+
         module.fit(sub_ctx, train_data, validate_data)
         model = module.get_model()
         with output_model as model_writer:
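Aside: the net effect of this pair of hunks is to unwrap the reader's return value at the read site. The old code unwrapped train_data inside the "if validate_data is not None:" branch (the removed "temp code"), so the unwrap was skipped whenever no validate set was supplied; calling .data directly on read_dataframe() applies it unconditionally to both inputs. A minimal sketch of the pattern, with Reader and DataframeArtifact as hypothetical stand-ins for FATE's reader API:

    from dataclasses import dataclass

    @dataclass
    class DataframeArtifact:
        data: dict  # the underlying dataframe payload; .data unwraps it

    class Reader:
        def __init__(self, uri: str):
            self.uri = uri

        def read_dataframe(self) -> DataframeArtifact:
            # a real reader would load from self.uri; dummy payload here
            return DataframeArtifact(data={"rows": 100})

    # unwrap at the read site, unconditionally, for train and validate alike
    train_data = Reader("train.csv").read_dataframe().data
    validate_data = Reader("validate.csv").read_dataframe().data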
20 changes: 14 additions & 6 deletions python/fate/ml/glm/hetero_lr/guest.py
@@ -179,6 +179,7 @@ def __init__(
         self.start_iter = 0
         self.end_iter = -1
         self.is_converged = False
+        self.with_weight = False
 
     def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):
         """
@@ -188,8 +189,9 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):
         loss = log2 - (1/N)*0.5*∑ywx + (1/N)*0.125*[∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)]
         """
         coef_count = train_data.shape[1]
+        # @todo: need to make sure add single-valued column works
         if self.init_param.fit_intercept:
-            train_data["intercept"] = 1
+            train_data["intercept"] = 1.0
             coef_count += 1
 
         # temp code start
@@ -204,15 +206,13 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):
             self.optimizer.init_optimizer(model_parameter_length=w.size()[0])
         # temp code end
 
-        """batch_loader = dataframe.DataLoader(
-            train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True,
-            return_weight=True
-        ) # @todo: include batch weight"""
         batch_loader = dataframe.DataLoader(
             train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True
         )
         if self.end_iter >= 0:
             self.start_iter = self.end_iter + 1
+        """if train_data.weight:
+            self.with_weight = True"""
         """for i, iter_ctx in ctx.range(self.start_iter, self.max_iter):"""
         # temp code start
         for i, iter_ctx in ctx.range(self.max_iter):
@@ -221,8 +221,16 @@ def fit_single_model(self, ctx, train_data, validate_data=None, with_weight=False):
             j = 0
             self.optimizer.set_iters(i)
             logger.info(f"self.optimizer set iters{i}")
-            for batch_ctx, (X, Y) in iter_ctx.iter(batch_loader):
+            # todo: if self.with_weight: include weight in batch result
+            # for batch_ctx, (X, Y, weight) in iter_ctx.iter(batch_loader):
+            # temp code start
+            # for batch_ctx, (X, Y) in iter_ctx.iter(batch_loader):
+            for batch_ctx, X, Y in [(iter_ctx, train_data, train_data.label)]:
+            # temp code end
                 logger.info(f"X: {X}, Y: {Y}")
+                # temp code start
+                X = X.values.as_tensor()
+                # temp code end
                 h = X.shape[0]
 
                 Xw = torch.matmul(X, w)
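Aside: the guest-side "temp code" swaps the mini-batch DataLoader for a single full-batch pass: the loop runs exactly once over the one-element list [(iter_ctx, train_data, train_data.label)], and X.values.as_tensor() materializes the dataframe's feature columns as a torch tensor. The loss quoted in the docstring is the second-order Taylor expansion of the logistic loss around w = 0. A plain-torch sketch of that objective, under two simplifying assumptions (the guest and host shares Wg*Xg + Wh*Xh are collapsed into one local Xw, and labels are taken in {-1, +1}; in the real protocol the host's share arrives encrypted):

    import math

    import torch

    torch.manual_seed(0)
    N, D = 100, 5
    X = torch.randn(N, D)
    Y = torch.randint(0, 2, (N, 1)).float() * 2 - 1  # labels in {-1, +1}
    w = torch.zeros(D, 1)

    Xw = torch.matmul(X, w)  # same forward step as the hunk's Xw = torch.matmul(X, w)
    # loss = log2 - (1/N)*0.5*∑ y*Xw + (1/N)*0.125*∑ (Xw)^2
    loss = math.log(2) - 0.5 * torch.mean(Y * Xw) + 0.125 * torch.mean(Xw ** 2)
    print(float(loss))  # 0.6931... = log 2 at w = 0, as the expansion predicts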
8 changes: 7 additions & 1 deletion python/fate/ml/glm/hetero_lr/host.py
@@ -167,9 +167,15 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=None
             j = 0
             self.optimizer.set_iters(i)
             logger.info(f"self.optimizer set iters{i}")
-            for batch_ctx, X in iter_ctx.iter(batch_loader):
+            # temp code start
+            # for batch_ctx, X in iter_ctx.iter(batch_loader):
+            for batch_ctx, X in zip([iter_ctx], [train_data]):
+            # temp code end
                 # h = X.shape[0]
                 logger.info(f"start batch {j}")
+                # temp code start
+                X = X.values.as_tensor()
+                # temp code end
                 Xw_h = 0.25 * torch.matmul(X, w)
                 if self.optimizer.l1_penalty or self.optimizer.l2_penalty:
                     Xw_h = self.optimizer.add_regular_to_grad(Xw_h, w)
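Aside: the host mirrors the guest's single-batch trick: zip([iter_ctx], [train_data]) pairs two singleton lists, so the loop body runs exactly once with the whole training set as the only batch. The 0.25 in Xw_h = 0.25 * torch.matmul(X, w) is the slope of the linearization sigmoid(z) ≈ 0.5 + 0.25*z that hetero LR uses so the host's contribution stays a plain linear term. A self-contained check of both points:

    import torch

    # 1) the singleton-zip loop shape (iter_ctx/train_data are placeholders)
    iter_ctx, train_data = object(), list(range(10))
    batches = list(zip([iter_ctx], [train_data]))
    assert len(batches) == 1              # a single "batch" ...
    assert batches[0][1] is train_data    # ... holding the full training set

    # 2) sigmoid(z) ≈ 0.5 + 0.25*z near zero, hence the 0.25 factor
    z = torch.linspace(-0.5, 0.5, 5)
    print(torch.sigmoid(z))   # tensor([0.3775, 0.4378, 0.5000, 0.5622, 0.6225])
    print(0.5 + 0.25 * z)     # tensor([0.3750, 0.4375, 0.5000, 0.5625, 0.6250])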
