Commit 33bcc19
rename lr cpn (#4659)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed Jun 30, 2023
1 parent 41ff683 commit 33bcc19
Showing 10 changed files with 94 additions and 107 deletions.
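The commit renames the hetero logistic-regression component and its ML modules from hetero_lr to coordinated_lr. A minimal before/after sketch of the import paths this touches; the class and module names are taken from the hunks below, and only the grouping into one snippet is illustrative:

    # Before this commit:
    #   from fate.ml.glm.hetero_lr import HeteroLrModuleGuest, HeteroLrModuleHost, HeteroLrModuleArbiter

    # After this commit:
    from fate.ml.glm.coordinated_lr import (
        CoordinatedLRModuleArbiter,
        CoordinatedLRModuleGuest,
        CoordinatedLRModuleHost,
    )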
examples/config.yaml (2 changes: 1 addition & 1 deletion)
@@ -9,4 +9,4 @@ parties: # parties default id

 work_mode: 0 # 0 for standalone, or 1 for cluster
 
-data_base_dir: "/data/projects/fate" # path to project base where data is located
+data_base_dir: "/Users/yuwu/PycharmProjects/FATE" # path to project base where data is located
python/fate/components/components/__init__.py (6 changes: 3 additions & 3 deletions)
@@ -51,10 +51,10 @@ def reader(self):
         return reader
 
     @_lazy_cpn
-    def hetero_lr(self):
-        from .hetero_lr import hetero_lr
+    def coordinated_lr(self):
+        from .coordinated_lr import coordinated_lr
 
-        return hetero_lr
+        return coordinated_lr
 
     @_lazy_cpn
     def homo_nn(self):
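The renamed component is still registered through the _lazy_cpn decorator, which defers the submodule import until the attribute is first accessed. A self-contained sketch of that pattern, assuming _lazy_cpn behaves like a memoizing property; the real decorator lives in this package and is not part of the diff:

    import functools

    def _lazy_cpn(fn):
        # Assumed behavior: resolve the component on first access, then cache it,
        # so importing the components package stays cheap.
        return functools.cached_property(fn)

    class _Components:
        @_lazy_cpn
        def coordinated_lr(self):
            # The real method does: from .coordinated_lr import coordinated_lr
            def coordinated_lr():
                ...
            return coordinated_lr

    c = _Components()
    assert c.coordinated_lr is c.coordinated_lr  # resolved once, cached afterwards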
python/fate/components/components/coordinated_lr.py
@@ -20,16 +20,16 @@


 @cpn.component(roles=[GUEST, HOST, ARBITER])
-def hetero_lr(ctx, role):
+def coordinated_lr(ctx, role):
     ...
 
 
-@hetero_lr.train()
+@coordinated_lr.train()
 def train(
     ctx: Context,
     role: Role,
     train_data: cpn.dataframe_input(roles=[GUEST, HOST]),
-    validate_data: cpn.dataframe_input(roles=[GUEST, HOST]),
+    validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True),
     learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(),
                                            default=params.LRSchedulerParam(method="constant",
                                                                            scheduler_params={"gamma": 0.1}),
@@ -54,7 +54,6 @@ def train(
     threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5,
                              desc="predict threshold for binary data"),
     train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
-    train_output_metric: cpn.json_metric_output(roles=[ARBITER]),
     output_model: cpn.json_model_output(roles=[GUEST, HOST]),
 ):
     if role.is_guest:
@@ -68,11 +67,10 @@ def train(
             batch_size, optimizer, learning_rate_scheduler, init_param
         )
     elif role.is_arbiter:
-        train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer, learning_rate_scheduler,
-                      train_output_metric)
+        train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer, learning_rate_scheduler)
 
 
-@hetero_lr.predict()
+@coordinated_lr.predict()
 def predict(
     ctx,
     role: Role,
@@ -87,7 +85,7 @@ def predict(
         predict_host(ctx, input_model, test_data, test_output_data)
 
 
-"""@hetero_lr.cross_validation()
+"""@coordinated_lr.cross_validation()
 def cross_validation(
     ctx: Context,
     role: Role,
@@ -104,9 +102,9 @@ def cross_validation(
     # TODO: split data
     for i, fold_ctx in cv_ctx.ctxs_range(num_fold):
         if role.is_guest:
-            from fate.ml.glm.hetero_lr import HeteroLrModuleGuest
+            from fate.ml.glm.coordinated_lr import CoordinatedLRModuleGuest
 
-            module = HeteroLrModuleGuest(max_iter=max_iter, learning_rate=learning_rate, batch_size=batch_size)
+            module = CoordinatedLRModuleGuest(max_iter=max_iter, learning_rate=learning_rate, batch_size=batch_size)
             train_data, validate_data = split_dataframe(data, num_fold, i)
             module.fit(fold_ctx, train_data)
             predicted = module.predict(fold_ctx, validate_data)
@@ -120,89 +118,82 @@ def cross_validation(

 def train_guest(ctx, train_data, validate_data, train_output_data, output_model, max_iter,
                 batch_size, optimizer_param, learning_rate_param, init_param, threshold):
-    from fate.ml.glm.hetero_lr import HeteroLrModuleGuest
+    from fate.ml.glm.coordinated_lr import CoordinatedLRModuleGuest
     # optimizer = optimizer_factory(optimizer_param)
 
     with ctx.sub_ctx("train") as sub_ctx:
-        module = HeteroLrModuleGuest(max_iter=max_iter, batch_size=batch_size,
-                                     optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
-                                     init_param=init_param, threshold=threshold)
+        module = CoordinatedLRModuleGuest(max_iter=max_iter, batch_size=batch_size,
+                                          optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
+                                          init_param=init_param, threshold=threshold)
         train_data = train_data.read()
 
         if validate_data is not None:
             validate_data = validate_data.read()
 
         module.fit(sub_ctx, train_data, validate_data)
         model = module.get_model()
-        with output_model as model_writer:
-            model_writer.write_model("hetero_lr_guest", model, metadata={"threshold": threshold})
+        output_model.write(model, metadata={"threshold": threshold})
 
     with ctx.sub_ctx("predict") as sub_ctx:
         predict_score = module.predict(sub_ctx, validate_data)
         predict_result = transform_to_predict_result(validate_data, predict_score, module.labels,
                                                      threshold=module.threshold, is_ovr=module.ovr,
                                                      data_type="test")
-        sub_ctx.writer(train_output_data).write_dataframe(predict_result)
+        train_output_data.write(predict_result)
 
 
 def train_host(ctx, train_data, validate_data, train_output_data, output_model, max_iter, batch_size,
                optimizer_param, learning_rate_param, init_param):
-    from fate.ml.glm.hetero_lr import HeteroLrModuleHost
+    from fate.ml.glm.coordinated_lr import CoordinatedLRModuleHost
 
     with ctx.sub_ctx("train") as sub_ctx:
-        module = HeteroLrModuleHost(max_iter=max_iter, batch_size=batch_size,
-                                    optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
-                                    init_param=init_param)
+        module = CoordinatedLRModuleHost(max_iter=max_iter, batch_size=batch_size,
+                                         optimizer_param=optimizer_param, learning_rate_param=learning_rate_param,
+                                         init_param=init_param)
         train_data = train_data.read()
 
         if validate_data is not None:
             validate_data = validate_data.read()
 
         module.fit(sub_ctx, train_data, validate_data)
         model = module.get_model()
-        with output_model as model_writer:
-            model_writer.write_model("hetero_lr_host", model, metadata={})
+        output_model.write(model, metadata={})
     with ctx.sub_ctx("predict") as sub_ctx:
         module.predict(sub_ctx, validate_data)
 
 
-def train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler,
-                  train_output_metric):
-    from fate.ml.glm.hetero_lr import HeteroLrModuleArbiter
-
-    ctx.metrics.handler.register_metrics(lr_loss=ctx.writer(train_output_metric))
+def train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler):
+    from fate.ml.glm.coordinated_lr import CoordinatedLRModuleArbiter
 
     with ctx.sub_ctx("train") as sub_ctx:
-        module = HeteroLrModuleArbiter(max_iter=max_iter, early_stop=early_stop, tol=tol, batch_size=batch_size,
-                                       optimizer_param=optimizer_param, learning_rate_param=learning_rate_scheduler)
+        module = CoordinatedLRModuleArbiter(max_iter=max_iter, early_stop=early_stop, tol=tol, batch_size=batch_size,
+                                            optimizer_param=optimizer_param,
+                                            learning_rate_param=learning_rate_scheduler)
         module.fit(sub_ctx)
 
 
 def predict_guest(ctx, input_model, test_data, test_output_data):
-    from fate.ml.glm.hetero_lr import HeteroLrModuleGuest
+    from fate.ml.glm.coordinated_lr import CoordinatedLRModuleGuest
 
     with ctx.sub_ctx("predict") as sub_ctx:
-        with input_model as model_reader:
-            model = model_reader.read_model()
-
-        module = HeteroLrModuleGuest.from_model(model)
+        model = input_model.read()
+        module = CoordinatedLRModuleGuest.from_model(model)
         # if module.threshold != 0.5:
         #     module.threshold = threshold
         test_data = test_data.read()
         predict_score = module.predict(sub_ctx, test_data)
         predict_result = transform_to_predict_result(test_data, predict_score, module.labels,
                                                      threshold=module.threshold, is_ovr=module.ovr,
                                                      data_type="test")
-        sub_ctx.writer(test_output_data).write_dataframe(predict_result)
+        test_output_data.write(predict_result)
 
 
 def predict_host(ctx, input_model, test_data, test_output_data):
-    from fate.ml.glm.hetero_lr import HeteroLrModuleHost
+    from fate.ml.glm.coordinated_lr import CoordinatedLRModuleHost
 
     with ctx.sub_ctx("predict") as sub_ctx:
-        with input_model as model_reader:
-            model = model_reader.read_model()
-        module = HeteroLrModuleHost.from_model(model)
+        model = input_model.read()
+        module = CoordinatedLRModuleHost.from_model(model)
         test_data = test_data.read()
         module.predict(sub_ctx, test_data)
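Beyond the rename, this file changes two conventions that recur in the hunks above: validate_data becomes an optional input, and outputs are written directly through the artifact objects (output_model.write(...), train_output_data.write(...)) instead of via model_writer or ctx.writer(...). A hedged sketch of the new calling convention; the artifact objects here are stand-ins for what the cpn framework injects at runtime:

    def train_guest_sketch(train_data, validate_data, train_output_data, output_model):
        # Sketch only: the artifact objects are injected by the cpn framework.
        data = train_data.read()
        # validate_data is declared optional=True after this commit, so it may be None:
        val = validate_data.read() if validate_data is not None else None
        model = {"metadata": {}, "estimator": {}}  # stand-in for module.get_model()
        output_model.write(model, metadata={"threshold": 0.5})  # was: model_writer.write_model(...)
        train_output_data.write(data)  # was: sub_ctx.writer(train_output_data).write_dataframe(...)
        return val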
python/fate/components/components/hetero_linr.py (14 changes: 5 additions & 9 deletions)
@@ -29,7 +29,7 @@ def train(
     ctx: Context,
     role: Role,
     train_data: cpn.dataframe_input(roles=[GUEST, HOST]),
-    validate_data: cpn.dataframe_input(roles=[GUEST, HOST]),
+    validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True),
     learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(),
                                            default=params.LRSchedulerParam(method="constant"),
                                            desc="learning rate scheduler, "
@@ -51,7 +51,6 @@ def train(
                              default=params.InitParam(method='zeros', fit_intercept=True),
                              desc="Model param init setting."),
     train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
-    train_output_metric: cpn.json_metric_output(roles=[ARBITER]),
     output_model: cpn.json_model_output(roles=[GUEST, HOST]),
 ):
     if role.is_guest:
@@ -65,8 +64,7 @@ def train(
             batch_size, optimizer, learning_rate_scheduler, init_param
         )
     elif role.is_arbiter:
-        train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer, learning_rate_scheduler,
-                      train_output_metric)
+        train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer, learning_rate_scheduler)
 
 
 @hetero_linr.predict()
@@ -98,12 +96,12 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,
         module.fit(sub_ctx, train_data, validate_data)
         model = module.get_model()
         with output_model as model_writer:
-            model_writer.write_model("hetero_linr_guest", model, metadata={})
+            model_writer.write_model("hetero_linr_host", model, metadata={})
 
     with ctx.sub_ctx("predict") as sub_ctx:
         predict_score = module.predict(sub_ctx, validate_data)
         predict_result = transform_to_predict_result(validate_data, predict_score, data_type="train")
-        sub_ctx.writer(train_output_data).write_dataframe(predict_result)
+        train_output_data.write(predict_result)
 
 
 def train_host(ctx, train_data, validate_data, train_output_data, output_model, max_iter, batch_size,
@@ -127,11 +125,9 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model,


 def train_arbiter(ctx, max_iter, early_stop, tol, batch_size, optimizer_param,
-                  learning_rate_param, train_output_metric):
+                  learning_rate_param):
     from fate.ml.glm.hetero_linr import HeteroLinRModuleArbiter
 
-    ctx.metrics.handler.register_metrics(linr_loss=ctx.writer(train_output_metric))
-
     with ctx.sub_ctx("train") as sub_ctx:
         module = HeteroLinRModuleArbiter(max_iter=max_iter, early_stop=early_stop, tol=tol, batch_size=batch_size,
                                          optimizer_param=optimizer_param, learning_rate_param=learning_rate_param)
python/fate/components/components/intersection.py (4 changes: 2 additions & 2 deletions)
@@ -37,7 +37,7 @@ def raw_intersect_guest(ctx, input_data, output_data):
     data = input_data.read()
     guest_intersect_obj = RawIntersectionGuest()
     intersect_data = guest_intersect_obj.fit(ctx, data)
-    ctx.writer(output_data).write_dataframe(intersect_data)
+    output_data.write(intersect_data)
 
 
 def raw_intersect_host(ctx, input_data, output_data):
@@ -46,4 +46,4 @@ def raw_intersect_host(ctx, input_data, output_data):
     data = input_data.read()
     host_intersect_obj = RawIntersectionHost()
     intersect_data = host_intersect_obj.fit(ctx, data)
-    ctx.writer(output_data).write_dataframe(intersect_data)
+    output_data.write(intersect_data)
python/fate/ml/glm/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
+from .coordinated_lr import CoordinatedLRModuleHost, CoordinatedLRModuleGuest, CoordinatedLRModuleArbiter
 from .hetero_linr import HeteroLinRModuleHost, HeteroLinRModuleGuest, HeteroLinRModuleArbiter
-from .hetero_lr import HeteroLrModuleHost, HeteroLrModuleGuest, HeteroLrModuleArbiter
python/fate/ml/glm/coordinated_lr/__init__.py
@@ -13,6 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .arbiter import HeteroLrModuleArbiter
-from .guest import HeteroLrModuleGuest
-from .host import HeteroLrModuleHost
+from .arbiter import CoordinatedLRModuleArbiter
+from .guest import CoordinatedLRModuleGuest
+from .host import CoordinatedLRModuleHost
python/fate/ml/glm/coordinated_lr/arbiter.py
@@ -26,7 +26,7 @@
 logger = logging.getLogger(__name__)
 
 
-class HeteroLrModuleArbiter(HeteroModule):
+class CoordinatedLRModuleArbiter(HeteroModule):
     def __init__(
         self,
         max_iter,
@@ -71,20 +71,20 @@ def fit(self, ctx: Context) -> None:
             for i, class_ctx in ctx.range(range(label_count)):
                 optimizer = copy.deepcopy(self.optimizer)
                 lr_scheduler = copy.deepcopy(self.lr_scheduler)
-                single_estimator = HeteroLrEstimatorArbiter(max_iter=self.max_iter,
-                                                            early_stop=self.early_stop,
-                                                            tol=self.tol,
-                                                            batch_size=self.batch_size,
-                                                            optimizer=optimizer,
-                                                            learning_rate_scheduler=lr_scheduler)
+                single_estimator = CoordinatedLREstimatorArbiter(max_iter=self.max_iter,
+                                                                 early_stop=self.early_stop,
+                                                                 tol=self.tol,
+                                                                 batch_size=self.batch_size,
+                                                                 optimizer=optimizer,
+                                                                 learning_rate_scheduler=lr_scheduler)
                 single_estimator.fit_single_model(class_ctx, decryptor)
                 self.estimator[i] = single_estimator
         else:
-            single_estimator = HeteroLrEstimatorArbiter(max_iter=self.max_iter,
-                                                        early_stop=self.early_stop,
-                                                        tol=self.tol,
-                                                        batch_size=self.batch_size,
-                                                        optimizer=self.optimizer)
+            single_estimator = CoordinatedLREstimatorArbiter(max_iter=self.max_iter,
+                                                             early_stop=self.early_stop,
+                                                             tol=self.tol,
+                                                             batch_size=self.batch_size,
+                                                             optimizer=self.optimizer)
             single_estimator.fit_single_model(ctx, decryptor)
             self.estimator = single_estimator

@@ -101,21 +101,21 @@ def to_model(self):
         }
 
     def from_model(cls, model):
-        lr = HeteroLrModuleArbiter(**model["metadata"])
+        lr = CoordinatedLRModuleArbiter(**model["metadata"])
         all_estimator = model["estimator"]
         if lr.ovr:
             lr.estimator = {
-                label: HeteroLrEstimatorArbiter().restore(d) for label, d in all_estimator.items()
+                label: CoordinatedLREstimatorArbiter().restore(d) for label, d in all_estimator.items()
             }
         else:
-            estimator = HeteroLrEstimatorArbiter()
+            estimator = CoordinatedLREstimatorArbiter()
             estimator.restore(all_estimator)
             lr.estimator = estimator
-            return lr
+        return lr
 
 
-class HeteroLrEstimatorArbiter(HeteroModule):
+class CoordinatedLREstimatorArbiter(HeteroModule):
     def __init__(
         self,
         max_iter=None,
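The arbiter module keeps the to_model/from_model pair for serialization. A hypothetical round trip, assuming from_model is exposed as a classmethod (its cls parameter suggests so) and the model-dict layout shown in the hunk above:

    # module is assumed to be a fitted CoordinatedLRModuleArbiter
    model = module.to_model()  # {"metadata": {...}, "estimator": ...}
    restored = CoordinatedLRModuleArbiter.from_model(model)
    # restored.estimator is a {label: CoordinatedLREstimatorArbiter} dict under one-vs-rest,
    # otherwise a single CoordinatedLREstimatorArbiter.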
(diff truncated: 2 of the 10 changed files are not shown)
