Skip to content

Commit

Permalink
Update Homo-LR:
Browse files Browse the repository at this point in the history
1. None label predict
2. Training check
3. Predict support in component
Update FATE-ML
1. Add tools for predict result formatting

Signed-off-by: cwj <talkingwallace@sohu.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
talkingwallace authored and sagewe committed Jul 21, 2023
1 parent f542bd9 commit 3e2c00e
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions python/fate/components/components/homo_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os

import pandas as pd
from fate.arch import Context
from fate.arch.dataframe import PandasReader
from fate.ml.glm.homo_lr.client import HomoLRClient
from fate.ml.glm.homo_lr.server import HomoLRServer
from fate.components.components.utils.predict_format import LABEL
from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
from fate.components.components.utils import consts
from fate.ml.utils.model_io import ModelIO


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,25 +56,40 @@ def train(
desc="Model param init setting."),
threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5,
desc="predict threshold for binary data"),
ovr: cpn.parameter(type=bool, default=False,
desc="predict threshold for binary data"),
label_num: cpn.parameter(type=params.conint(ge=2), default=None),
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True),
output_model: cpn.json_model_output(roles=[GUEST, HOST])
train_output_model: cpn.json_model_output(roles=[GUEST, HOST])
):

sub_ctx = ctx.sub_ctx(consts.TRAIN)

if role.is_guest or role.is_host: # is client

logger.info('homo lr component: client start training')
logger.info('optim param {} init param {}'.format(optimizer.dict(), init_param.dict()))

client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer, init_param=init_param,
learning_rate_scheduler=0.01, threshold=threshold)
client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer.dict(), init_param=init_param.dict(),
learning_rate_scheduler=0.01, threshold=threshold, ovr=ovr, label_num=label_num)
train_df = train_data.read()
validate_df = validate_data.read() if validate_data else None
client.fit(sub_ctx, train_df, validate_df)
output_model.write({"aaa": 1}, metadata={"bbb": 2})
model_dict = client.get_model().dict()

train_rs = client.predict(sub_ctx, train_df)
if validate_df:
validate_rs = client.predict(sub_ctx, validate_df)
ret_df = train_rs.vstack(validate_rs)
else:
ret_df = train_rs

train_output_data.write(ret_df)
train_output_model.write(model_dict, metadata=model_dict['meta'])

elif role.is_arbiter: # is server
logger.info('hello')
logger.info('homo lr component: server start training')
server = HomoLRServer()
server.fit(sub_ctx)

Expand All @@ -97,7 +109,13 @@ def predict(
):

if role.is_guest or role.is_host: # is client
pass

client = HomoLRClient(batch_size=batch_size, threshold=threshold)
model_input = predict_input_model.read()
model_data = ModelIO.from_dict(model_input)
logger.info('model input is {}'.format(model_input))
pred_rs = client.predict(ctx, test_data.read())


elif role.is_arbiter: # is server
logger.info("arbiter skip predict")

0 comments on commit 3e2c00e

Please sign in to comment.