coordinated lr & linr add key length & floating point precision param(#…
Browse files Browse the repository at this point in the history
…4659)

binning add key length param(#4660)
edit fate test config(#5008)
fix pipeline examples

Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed Sep 5, 2023
1 parent 52dc006 commit 8c83c97
Showing 17 changed files with 224 additions and 91 deletions.
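The diffs below add two component parameters to CoordinatedLR and CoordinatedLinR: key_length (default 1024) and floating_point_precision (default 23). A hypothetical usage sketch follows, assuming the fate_client API used in the example scripts in this commit; the party ids, dataset names, epoch count and the PSI output name are illustrative assumptions, and only the two parameter names and their defaults come from the diff.

```python
# Hypothetical usage sketch (not part of this commit): wiring the two new
# parameters into a fate_client pipeline. Party ids, dataset names and epochs
# are illustrative; key_length / floating_point_precision and their defaults
# (1024 / 23) are the values introduced by this commit.
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import PSI, CoordinatedLR
from fate_client.pipeline.interface import DataWarehouseChannel

pipeline = FateFlowPipeline().set_roles(guest="9999", host="10000", arbiter="9998")

psi_0 = PSI("psi_0")
psi_0.guest.component_setting(
    input_data=DataWarehouseChannel(name="breast_hetero_guest", namespace="experiment"))
psi_0.hosts[0].component_setting(
    input_data=DataWarehouseChannel(name="breast_hetero_host", namespace="experiment"))

lr_0 = CoordinatedLR(
    "lr_0",
    train_data=psi_0.outputs["output_data"],
    epochs=3,
    key_length=1024,               # key length for the arbiter-side encryption keys
    floating_point_precision=23,   # fixed-point fractional bits for encrypted computation
)

pipeline.add_task(psi_0)
pipeline.add_task(lr_0)
pipeline.compile()
# pipeline.fit()
```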
2 changes: 1 addition & 1 deletion examples/pipeline/coordinated_linr/test_linr.py
@@ -55,7 +55,7 @@ def main(config="../config.yaml", namespace=""):
pipeline.add_task(linr_0)
pipeline.add_task(evaluation_0)
pipeline.compile()
print(pipeline.get_dag())
# print(pipeline.get_dag())
pipeline.fit()

pipeline.deploy([psi_0, linr_0])
2 changes: 1 addition & 1 deletion examples/pipeline/coordinated_linr/test_linr_multi_host.py
@@ -60,7 +60,7 @@ def main(config="../config.yaml", namespace=""):
pipeline.add_task(evaluation_0)

pipeline.compile()
print(pipeline.get_dag())
# print(pipeline.get_dag())
pipeline.fit()

pipeline.deploy([psi_0, linr_0])
17 changes: 10 additions & 7 deletions examples/pipeline/multi_model/test_multi.py
@@ -16,7 +16,7 @@

from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, \
FeatureScale, Union, DataSplit, CoordinatedLR, Statistics, Sample, Evaluation
FeatureScale, Union, DataSplit, CoordinatedLR, CoordinatedLinR, Statistics, Sample, Evaluation
from fate_client.pipeline.interface import DataWarehouseChannel
from fate_client.pipeline.utils import test_utils

@@ -73,10 +73,11 @@ def main(config="../config.yaml", namespace=""):

lr_0 = CoordinatedLR("lr_0", train_data=selection_0.outputs["train_output_data"],
validate_data=selection_1.outputs["test_output_data"], epochs=3)
linr_0 = CoordinatedLR("linr_0", train_data=selection_0.outputs["train_output_data"],
validate_data=selection_1.outputs["test_output_data"], epochs=3)
linr_0 = CoordinatedLinR("linr_0", train_data=selection_0.outputs["train_output_data"],
validate_data=selection_1.outputs["test_output_data"], epochs=3)

evaluation_0 = Evaluation("evaluation_0", input_data=lr_0.outputs["train_output_data"],
default_eval_setting="binary",
runtime_roles=["guest"])
evaluation_1 = Evaluation("evaluation_1", input_data=linr_0.outputs["train_output_data"],
default_eval_setting="regression",
@@ -107,10 +108,12 @@ def main(config="../config.yaml", namespace=""):
predict_pipeline = FateFlowPipeline()

deployed_pipeline = pipeline.get_deployed_pipeline()
psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
namespace=f"experiment{namespace}"))
psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
namespace=f"experiment{namespace}"))
deployed_pipeline.psi_0.guest.component_setting(
input_data=DataWarehouseChannel(name="breast_hetero_guest",
namespace=f"experiment{namespace}"))
deployed_pipeline.psi_0.hosts[0].component_setting(
input_data=DataWarehouseChannel(name="breast_hetero_host",
namespace=f"experiment{namespace}"))

predict_pipeline.add_task(deployed_pipeline)
predict_pipeline.compile()
4 changes: 2 additions & 2 deletions python/fate/arch/tensor/distributed/_ops_others.py
@@ -1,6 +1,6 @@
from fate.arch.tensor import _custom_ops
import torch

from fate.arch.tensor import _custom_ops
from ._tensor import DTensor, implements


@@ -11,4 +11,4 @@ def to_local_f(input: DTensor):

@implements(_custom_ops.encode_as_int_f)
def encode_as_int_f(input: DTensor, precision):
return input.shardings.map_shard(lambda x: (x * 2**precision).astype(torch.int64), dtype=torch.int64)
return DTensor(input.shardings.map_shard(lambda x: (x * 2 ** precision).type(torch.int64), dtype=torch.int64))
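encode_as_int_f now wraps the result in a DTensor and casts each shard with Tensor.type(torch.int64) rather than astype, which torch tensors do not provide. Below is a minimal local sketch of the fixed-point transform applied per shard; the decode step is added here for illustration only and is not part of this diff.

```python
# Local illustration of the fixed-point encoding used in encode_as_int_f above.
# Assumption: decoding is simply the inverse scaling; only encoding appears in this diff.
import torch

precision = 23                       # fractional bits, matching the new component default
x = torch.tensor([0.5, -1.25, 3.0])

encoded = (x * 2 ** precision).type(torch.int64)   # same transform as encode_as_int_f
decoded = encoded.to(torch.float64) / 2 ** precision

print(encoded)   # tensor([  4194304, -10485760,  25165824])
print(decoded)   # tensor([ 0.5000, -1.2500,  3.0000], dtype=torch.float64)
```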
39 changes: 27 additions & 12 deletions python/fate/components/components/coordinated_linr.py
@@ -59,6 +59,11 @@ def train(
default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None),
desc="Model param init setting.",
),
key_length: cpn.parameter(type=params.conint(ge=0), default=1024, desc="key length"),
floating_point_precision: cpn.parameter(
type=params.conint(ge=0),
default=23,
desc="floating point precision, "),
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]),
warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True),
@@ -73,16 +78,18 @@ def train(
if role.is_guest:
train_guest(
ctx, train_data, validate_data, train_output_data, output_model, epochs,
batch_size, optimizer, learning_rate_scheduler, init_param, warm_start_model
batch_size, optimizer, learning_rate_scheduler, init_param, floating_point_precision,
warm_start_model
)
elif role.is_host:
train_host(
ctx, train_data, validate_data, train_output_data, output_model, epochs,
batch_size, optimizer, learning_rate_scheduler, init_param, warm_start_model
batch_size, optimizer, learning_rate_scheduler, init_param, floating_point_precision,
warm_start_model
)
elif role.is_arbiter:
train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model,
warm_start_model)
train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler,
key_length, output_model, warm_start_model)


@coordinated_linr.predict()
@@ -137,6 +144,11 @@ def cross_validation(
cv_param: cpn.parameter(type=params.cv_param(),
default=params.CVParam(n_splits=5, shuffle=False, random_state=None),
desc="cross validation param"),
floating_point_precision: cpn.parameter(
type=params.conint(ge=0),
default=23,
desc="floating point precision, "),
key_length: cpn.parameter(type=params.conint(ge=0), default=1024, desc="key length"),
metrics: cpn.parameter(type=params.metrics_param(), default=["mse"]),
output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"),
cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True),
@@ -157,6 +169,7 @@ def cross_validation(
batch_size=batch_size,
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
key_length=key_length,
)
module.fit(fold_ctx)
i += 1
@@ -173,7 +186,8 @@ def cross_validation(
batch_size=batch_size,
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
init_param=init_param
init_param=init_param,
floating_point_precision=floating_point_precision
)
module.fit(fold_ctx, train_data, validate_data)
if output_cv_data:
@@ -200,7 +214,8 @@ def cross_validation(
batch_size=batch_size,
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
init_param=init_param
init_param=init_param,
floating_point_precision=floating_point_precision
)
module.fit(fold_ctx, train_data, validate_data)
if output_cv_data:
@@ -212,7 +227,7 @@ def cross_validation(


def train_guest(ctx, train_data, validate_data, train_output_data, output_model, epochs,
batch_size, optimizer_param, learning_rate_param, init_param, input_model):
batch_size, optimizer_param, learning_rate_param, init_param, floating_point_precision, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
@@ -222,7 +237,7 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,
else:
module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
init_param=init_param, floating_point_precision=floating_point_precision)
logger.info(f"coordinated linr guest start train")
sub_ctx = ctx.sub_ctx("train")
train_data = train_data.read()
@@ -252,7 +267,7 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,


def train_host(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size,
optimizer_param, learning_rate_param, init_param, input_model):
optimizer_param, learning_rate_param, init_param, floating_point_precision, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
@@ -262,7 +277,7 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model,
else:
module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
init_param=init_param)
init_param=init_param, floating_point_precision=floating_point_precision)
logger.info(f"coordinated linr host start train")
sub_ctx = ctx.sub_ctx("train")

@@ -282,7 +297,7 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model,


def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param,
learning_rate_param, output_model, input_model):
learning_rate_param, key_length, output_model, input_model):
if input_model is not None:
logger.info(f"warm start model provided")
model = input_model.read()
@@ -292,7 +307,7 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param,
else:
module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size,
optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
)
key_length=key_length)
logger.info(f"coordinated linr arbiter start train")

sub_ctx = ctx.sub_ctx("train")
50 changes: 36 additions & 14 deletions python/fate/components/components/coordinated_lr.py
@@ -44,22 +44,27 @@ def train(
"refer to torch.optim.lr_scheduler",
),
epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"),
batch_size: cpn.parameter(
type=params.conint(ge=10),
default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None"
),
optimizer: cpn.parameter(
type=params.optimizer_param(),
default=params.OptimizerParam(
method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0}
batch_size: cpn.parameter(
type=params.conint(ge=10),
default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None"
),
),
optimizer: cpn.parameter(
type=params.optimizer_param(),
default=params.OptimizerParam(
method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0}
),
),
floating_point_precision: cpn.parameter(
type=params.conint(ge=0),
default=23,
desc="floating point precision, "),
tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4),
early_stop: cpn.parameter(
type=params.string_choice(["weight_diff", "diff", "abs"]),
default="diff",
desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}",
),
key_length: cpn.parameter(type=params.conint(ge=0), default=1024, desc="key length"),
init_param: cpn.parameter(
type=params.init_param(),
default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None),
@@ -92,6 +97,7 @@ def train(
learning_rate_scheduler,
init_param,
threshold,
floating_point_precision,
warm_start_model
)
elif role.is_host:
@@ -106,6 +112,7 @@ def train(
optimizer,
learning_rate_scheduler,
init_param,
floating_point_precision,
warm_start_model
)
elif role.is_arbiter:
@@ -115,6 +122,7 @@ def train(
tol, batch_size,
optimizer,
learning_rate_scheduler,
key_length,
output_model,
warm_start_model)

@@ -172,6 +180,11 @@ def cross_validation(
threshold: cpn.parameter(
type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data"
),
key_length: cpn.parameter(type=params.conint(ge=0), default=1024, desc="key length"),
floating_point_precision: cpn.parameter(
type=params.conint(ge=0),
default=23,
desc="floating point precision, "),
cv_param: cpn.parameter(type=params.cv_param(),
default=params.CVParam(n_splits=5, shuffle=False, random_state=None),
desc="cross validation param"),
@@ -195,6 +208,7 @@ def cross_validation(
batch_size=batch_size,
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
key_length=key_length,
)
module.fit(fold_ctx)
i += 1
@@ -213,6 +227,7 @@ def cross_validation(
learning_rate_param=learning_rate_scheduler,
init_param=init_param,
threshold=threshold,
floating_point_precision=floating_point_precision,
)
module.fit(fold_ctx, train_data, validate_data)
if output_cv_data:
@@ -241,6 +256,7 @@ def cross_validation(
optimizer_param=optimizer,
learning_rate_param=learning_rate_scheduler,
init_param=init_param,
floating_point_precision=floating_point_precision
)
module.fit(fold_ctx, train_data, validate_data)
if output_cv_data:
@@ -253,8 +269,8 @@ def cross_validation(

def train_guest(
ctx,
train_data,
validate_data,
train_data,
validate_data,
train_output_data,
output_model,
epochs,
@@ -263,6 +279,7 @@ def train_guest(
learning_rate_param,
init_param,
threshold,
floating_point_precision,
input_model
):
if input_model is not None:
Expand All @@ -280,6 +297,7 @@ def train_guest(
learning_rate_param=learning_rate_param,
init_param=init_param,
threshold=threshold,
floating_point_precision=floating_point_precision
)
# optimizer = optimizer_factory(optimizer_param)
logger.info(f"coordinated lr guest start train")
@@ -318,8 +336,8 @@ def train_guest(


def train_host(
ctx,
train_data,
ctx,
train_data,
validate_data,
train_output_data,
output_model,
@@ -328,6 +346,7 @@ def train_host(
optimizer_param,
learning_rate_param,
init_param,
floating_point_precision,
input_model
):
if input_model is not None:
@@ -343,6 +362,7 @@ def train_host(
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_param,
init_param=init_param,
floating_point_precision=floating_point_precision
)
logger.info(f"coordinated lr host start train")
sub_ctx = ctx.sub_ctx("train")
@@ -362,7 +382,8 @@ def train_host(
module.predict(sub_ctx, validate_data)


def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, output_model,
def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, key_length,
output_model,
input_model):
if input_model is not None:
logger.info(f"warm start model provided")
@@ -378,6 +399,7 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, lea
batch_size=batch_size,
optimizer_param=optimizer_param,
learning_rate_param=learning_rate_scheduler,
key_length=key_length
)
logger.info(f"coordinated lr arbiter start train")
sub_ctx = ctx.sub_ctx("train")
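Both coordinated_linr.py and coordinated_lr.py route the new parameters the same way: floating_point_precision is passed to the guest and host training modules, while key_length is passed only to the arbiter module, presumably because the arbiter is the party that generates the encryption key pair. The sketch below is a simplified, runnable paraphrase of that dispatch; the stub dataclasses are placeholders for the real Guest/Host/Arbiter modules, not the actual classes.

```python
# Simplified paraphrase of the role dispatch in the two component diffs above.
# Stub dataclasses stand in for CoordinatedLR*/CoordinatedLinR* Guest/Host/Arbiter modules.
from dataclasses import dataclass

@dataclass
class GuestModule:
    epochs: int
    floating_point_precision: int  # guest/host fixed-point-encode values for encrypted math

@dataclass
class HostModule:
    epochs: int
    floating_point_precision: int

@dataclass
class ArbiterModule:
    epochs: int
    key_length: int  # only the arbiter takes key_length in the diffs above

def build_module(role, epochs=20, key_length=1024, floating_point_precision=23):
    if role == "guest":
        return GuestModule(epochs, floating_point_precision)
    if role == "host":
        return HostModule(epochs, floating_point_precision)
    if role == "arbiter":
        return ArbiterModule(epochs, key_length)
    raise ValueError(f"unknown role: {role}")

print(build_module("guest"))    # GuestModule(epochs=20, floating_point_precision=23)
print(build_module("arbiter"))  # ArbiterModule(epochs=20, key_length=1024)
```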