diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py index 7dc567cc05..618b74c1d0 100644 --- a/python/fate/components/components/homo_lr.py +++ b/python/fate/components/components/homo_lr.py @@ -13,16 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import os - -import pandas as pd from fate.arch import Context -from fate.arch.dataframe import PandasReader from fate.ml.glm.homo_lr.client import HomoLRClient from fate.ml.glm.homo_lr.server import HomoLRServer -from fate.components.components.utils.predict_format import LABEL from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params from fate.components.components.utils import consts +from fate.ml.utils.model_io import ModelIO + logger = logging.getLogger(__name__) @@ -59,25 +56,40 @@ def train( desc="Model param init setting."), threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data"), + ovr: cpn.parameter(type=bool, default=False, + desc="predict threshold for binary data"), + label_num: cpn.parameter(type=params.conint(ge=2), default=None), train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True), - output_model: cpn.json_model_output(roles=[GUEST, HOST]) + train_output_model: cpn.json_model_output(roles=[GUEST, HOST]) ): sub_ctx = ctx.sub_ctx(consts.TRAIN) if role.is_guest or role.is_host: # is client + logger.info('homo lr component: client start training') logger.info('optim param {} init param {}'.format(optimizer.dict(), init_param.dict())) - client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer, init_param=init_param, - learning_rate_scheduler=0.01, threshold=threshold) + client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer.dict(), init_param=init_param.dict(), + learning_rate_scheduler=0.01, threshold=threshold, ovr=ovr, label_num=label_num) train_df = train_data.read() validate_df = validate_data.read() if validate_data else None client.fit(sub_ctx, train_df, validate_df) - output_model.write({"aaa": 1}, metadata={"bbb": 2}) + model_dict = client.get_model().dict() + + train_rs = client.predict(sub_ctx, train_df) + if validate_df: + validate_rs = client.predict(sub_ctx, validate_df) + ret_df = train_rs.vstack(validate_rs) + else: + ret_df = train_rs + + train_output_data.write(ret_df) + train_output_model.write(model_dict, metadata=model_dict['meta']) + elif role.is_arbiter: # is server - logger.info('hello') + logger.info('homo lr component: server start training') server = HomoLRServer() server.fit(sub_ctx) @@ -97,7 +109,13 @@ def predict( ): if role.is_guest or role.is_host: # is client - pass + + client = HomoLRClient(batch_size=batch_size, threshold=threshold) + model_input = predict_input_model.read() + model_data = ModelIO.from_dict(model_input) + logger.info('model input is {}'.format(model_input)) + pred_rs = client.predict(ctx, test_data.read()) + elif role.is_arbiter: # is server logger.info("arbiter skip predict")