Skip to content

Commit

Permalink
coordinated lr & linr support warm start (#4659)
Browse files (browse the repository at this point in the history)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
nemirorox authored and sagewe committed Jul 21, 2023
1 parent 44a45d0 commit 10691e9
Show file tree
Hide file tree
Showing 10 changed files with 392 additions and 225 deletions.
57 changes: 42 additions & 15 deletions python/fate/components/components/coordinated_linr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def train(
desc="Model param init setting."),
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]),
warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True),

):
logger.info(f"enter coordinated linr train")
# temp code start
Expand All @@ -69,15 +71,16 @@ def train(
if role.is_guest:
train_guest(
ctx, train_data, validate_data, train_output_data, output_model, epochs,
batch_size, optimizer, learning_rate_scheduler, init_param
batch_size, optimizer, learning_rate_scheduler, init_param, warm_start_model
)
elif role.is_host:
train_host(
ctx, train_data, validate_data, train_output_data, output_model, epochs,
batch_size, optimizer, learning_rate_scheduler, init_param
batch_size, optimizer, learning_rate_scheduler, init_param, warm_start_model
)
elif role.is_arbiter:
train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model)
train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model,
warm_start_model)


@coordinated_linr.predict()
Expand Down Expand Up @@ -204,12 +207,19 @@ def cross_validation(


def train_guest(ctx, train_data, validate_data, train_output_data, output_model, epochs,
batch_size, optimizer_param, learning_rate_param, init_param):
batch_size, optimizer_param, learning_rate_param, init_param, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
module = CoordinatedLinRModuleGuest.from_model(model)
module.epochs = epochs
module.batch_size = batch_size
else:
module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
logger.info(f"coordinated linr guest start train")
sub_ctx = ctx.sub_ctx("train")
module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
train_data = train_data.read()
if validate_data is not None:
validate_data = validate_data.read()
Expand All @@ -224,6 +234,7 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,
predict_result = transform_to_predict_result(train_data, predict_score,
data_type="train")
if validate_data is not None:
sub_ctx = ctx.sub_ctx("validate_predict")
predict_score = module.predict(sub_ctx, validate_data)
validate_predict_result = transform_to_predict_result(validate_data, predict_score,
data_type="validate")
Expand All @@ -232,12 +243,20 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,


def train_host(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size,
optimizer_param, learning_rate_param, init_param):
optimizer_param, learning_rate_param, init_param, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
module = CoordinatedLinRModuleHost.from_model(model)
module.epochs = epochs
module.batch_size = batch_size
else:
module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
logger.info(f"coordinated linr host start train")
sub_ctx = ctx.sub_ctx("train")
module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)

train_data = train_data.read()
if validate_data is not None:
validate_data = validate_data.read()
Expand All @@ -249,17 +268,25 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model,
sub_ctx = ctx.sub_ctx("predict")
module.predict(sub_ctx, train_data)
if validate_data is not None:
sub_ctx = ctx.sub_ctx("validate_predict")
module.predict(sub_ctx, validate_data)


def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param,
learning_rate_param, output_model):
learning_rate_param, output_model, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
module = CoordinatedLinRModuleArbiter.from_model(model)
module.epochs = epochs
module.batch_size = batch_size
else:
module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
)
logger.info(f"coordinated linr arbiter start train")

sub_ctx = ctx.sub_ctx("train")
module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
)
module.fit(sub_ctx)

model = module.get_model()
Expand Down
141 changes: 88 additions & 53 deletions python/fate/components/components/coordinated_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,25 @@ def train(
default="diff",
desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}",
),
init_param: cpn.parameter(
type=params.init_param(),
default=params.InitParam(method="zeros", fit_intercept=True),
desc="Model param init setting.",
),
threshold: cpn.parameter(
type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data"
),
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]),
init_param: cpn.parameter(
type=params.init_param(),
default=params.InitParam(method="zeros", fit_intercept=True),
desc="Model param init setting.",
),
threshold: cpn.parameter(
type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data"
),
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]),
warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True),
):
logger.info(f"enter coordinated lr train")
# temp code start
optimizer = optimizer.dict()
learning_rate_scheduler = learning_rate_scheduler.dict()
init_param = init_param.dict()
# temp code end

if role.is_guest:
train_guest(
ctx,
Expand All @@ -90,6 +92,7 @@ def train(
learning_rate_scheduler,
init_param,
threshold,
warm_start_model
)
elif role.is_host:
train_host(
Expand All @@ -103,9 +106,17 @@ def train(
optimizer,
learning_rate_scheduler,
init_param,
warm_start_model
)
elif role.is_arbiter:
train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model)
train_arbiter(ctx,
epochs,
early_stop,
tol, batch_size,
optimizer,
learning_rate_scheduler,
output_model,
warm_start_model)


@coordinated_lr.predict()
Expand Down Expand Up @@ -242,28 +253,34 @@ def train_guest(
ctx,
train_data,
validate_data,
train_output_data,
output_model,
epochs,
batch_size,
optimizer_param,
learning_rate_param,
init_param,
threshold,
train_output_data,
output_model,
epochs,
batch_size,
optimizer_param,
learning_rate_param,
init_param,
threshold,
input_model
):
from fate.arch.dataframe import DataFrame

if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
module = CoordinatedLRModuleGuest.from_model(model)
module.epochs = epochs
module.batch_size = batch_size
else:
module = CoordinatedLRModuleGuest(
epochs=epochs,
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_param,
init_param=init_param,
threshold=threshold,
)
# optimizer = optimizer_factory(optimizer_param)
logger.info(f"coordinated lr guest start train")
sub_ctx = ctx.sub_ctx("train")
module = CoordinatedLRModuleGuest(
epochs=epochs,
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_param,
init_param=init_param,
threshold=threshold,
)
train_data = train_data.read()

if validate_data is not None:
Expand All @@ -281,6 +298,7 @@ def train_guest(
train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="train"
)
if validate_data is not None:
sub_ctx = ctx.sub_ctx("validate_predict")
predict_score = module.predict(sub_ctx, validate_data)
validate_predict_result = transform_to_predict_result(
validate_data,
Expand All @@ -297,24 +315,32 @@ def train_guest(
def train_host(
ctx,
train_data,
validate_data,
train_output_data,
output_model,
epochs,
batch_size,
optimizer_param,
learning_rate_param,
init_param,
validate_data,
train_output_data,
output_model,
epochs,
batch_size,
optimizer_param,
learning_rate_param,
init_param,
input_model
):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
module = CoordinatedLRModuleHost.from_model(model)
module.epochs = epochs
module.batch_size = batch_size
else:
module = CoordinatedLRModuleHost(
epochs=epochs,
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_param,
init_param=init_param,
)
logger.info(f"coordinated lr host start train")
sub_ctx = ctx.sub_ctx("train")
module = CoordinatedLRModuleHost(
epochs=epochs,
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_param,
init_param=init_param,
)
train_data = train_data.read()

if validate_data is not None:
Expand All @@ -327,20 +353,29 @@ def train_host(
sub_ctx = ctx.sub_ctx("predict")
module.predict(sub_ctx, train_data)
if validate_data is not None:
sub_ctx = ctx.sub_ctx("validate_predict")
module.predict(sub_ctx, validate_data)


def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, output_model):
def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, output_model,
input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
module = CoordinatedLRModuleArbiter.from_model(model)
module.epochs = epochs
module.batch_size = batch_size
else:
module = CoordinatedLRModuleArbiter(
epochs=epochs,
early_stop=early_stop,
tol=tol,
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_scheduler,
)
logger.info(f"coordinated lr arbiter start train")
sub_ctx = ctx.sub_ctx("train")
module = CoordinatedLRModuleArbiter(
epochs=epochs,
early_stop=early_stop,
tol=tol,
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_scheduler,
)
module.fit(sub_ctx)
model = module.get_model()
output_model.write(model, metadata={})
Expand Down
40 changes: 22 additions & 18 deletions python/fate/ml/glm/hetero/coordinated_linr/arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,23 @@ def __init__(
def fit(self, ctx: Context) -> None:
encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048))
ctx.hosts("encryptor").put(encryptor)
optimizer = Optimizer(
self.optimizer_param["method"],
self.optimizer_param["penalty"],
self.optimizer_param["alpha"],
self.optimizer_param["optimizer_params"],
)
lr_scheduler = LRScheduler(self.learning_rate_param["method"],
self.learning_rate_param["scheduler_params"])
single_estimator = HeteroLinrEstimatorArbiter(epochs=self.epochs,
early_stop=self.early_stop,
tol=self.tol,
batch_size=self.batch_size,
optimizer=optimizer,
learning_rate_scheduler=lr_scheduler)
single_estimator.fit_model(ctx, decryptor)
self.estimator = single_estimator
if self.estimator is None:
optimizer = Optimizer(
self.optimizer_param["method"],
self.optimizer_param["penalty"],
self.optimizer_param["alpha"],
self.optimizer_param["optimizer_params"],
)
lr_scheduler = LRScheduler(self.learning_rate_param["method"],
self.learning_rate_param["scheduler_params"])
single_estimator = HeteroLinrEstimatorArbiter(epochs=self.epochs,
early_stop=self.early_stop,
tol=self.tol,
batch_size=self.batch_size,
optimizer=optimizer,
learning_rate_scheduler=lr_scheduler)
self.estimator = single_estimator
self.estimator.fit_model(ctx, decryptor)

def get_model(self):
return {
Expand All @@ -76,6 +77,7 @@ def get_model(self):
"optimizer_param": self.optimizer_param},
}

@classmethod
def from_model(cls, model):
linr = CoordinatedLinRModuleArbiter(model["meta"]["epochs"],
model["meta"]["early_stop"],
Expand Down Expand Up @@ -107,7 +109,8 @@ def __init__(
self.optimizer = optimizer
self.lr_scheduler = learning_rate_scheduler

self.converge_func = converge_func_factory(early_stop, tol)
if early_stop is not None:
self.converge_func = converge_func_factory(early_stop, tol)
self.start_epoch = 0
self.end_epoch = -1
self.is_converged = False
Expand All @@ -121,7 +124,7 @@ def fit_model(self, ctx, decryptor):
optimizer_ready = False
else:
optimizer_ready = True
self.start_epoch = self.end_epoch + 1
# self.start_epoch = self.end_epoch + 1

for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs):
iter_loss = None
Expand Down Expand Up @@ -204,4 +207,5 @@ def restore(self, model):
self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer)
self.end_epoch = model["end_epoch"]
self.is_converged = model["is_converged"]
self.converge_func = converge_func_factory(self.early_stop, self.tol)
# self.start_epoch = model["end_epoch"] + 1
Loading

0 comments on commit 10691e9

Please sign in to comment.