
Commit

fix lr & linr export model param & ovr training (#4659)
lr & linr use predict utils (#4659)

Signed-off-by: Yu Wu <yolandawu131@gmail.com>
Signed-off-by: sagewe <wbwmat@gmail.com>
nemirorox authored and sagewe committed Jul 21, 2023
1 parent a792198 commit 4e21827
Showing 6 changed files with 76 additions and 42 deletions.
38 changes: 23 additions & 15 deletions python/fate/components/components/coordinated_linr.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging

from fate.arch import Context
from fate.arch.dataframe import DataFrame
from fate.components.components.utils import consts, tools
from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
from fate.ml.glm import CoordinatedLinRModuleArbiter, CoordinatedLinRModuleGuest, CoordinatedLinRModuleHost

@@ -176,15 +176,18 @@ def cross_validation(
module.fit(fold_ctx, train_data, validate_data)
if output_cv_data:
sub_ctx = fold_ctx.sub_ctx("predict_train")
predict_score = module.predict(sub_ctx, train_data)
train_predict_result = transform_to_predict_result(
train_predict_df = module.predict(sub_ctx, train_data)
"""train_predict_result = transform_to_predict_result(
train_data, predict_score, data_type="train"
)
)"""
train_predict_result = tools.add_dataset_type(train_predict_df, consts.TRAIN_SET)
sub_ctx = fold_ctx.sub_ctx("predict_validate")
predict_score = module.predict(sub_ctx, validate_data)
validate_predict_result = transform_to_predict_result(
validate_predict_df = module.predict(sub_ctx, validate_data)
validate_predict_result = tools.add_dataset_type(validate_predict_df, consts.VALIDATE_SET)
"""validate_predict_result = transform_to_predict_result(
validate_data, predict_score, data_type="predict"
)
"""
predict_result = DataFrame.vstack([train_predict_result, validate_predict_result])
next(cv_output_datas).write(df=predict_result)

@@ -230,14 +233,18 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model,

sub_ctx = ctx.sub_ctx("predict")

predict_score = module.predict(sub_ctx, train_data)
predict_result = transform_to_predict_result(train_data, predict_score,
data_type="train")
predict_df = module.predict(sub_ctx, train_data)
"""predict_result = transform_to_predict_result(train_data, predict_score,
data_type="train")"""
predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET)
if validate_data is not None:
sub_ctx = ctx.sub_ctx("validate_predict")
predict_score = module.predict(sub_ctx, validate_data)
validate_predict_result = transform_to_predict_result(validate_data, predict_score,
predict_df = module.predict(sub_ctx, validate_data)
validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET)

"""validate_predict_result = transform_to_predict_result(validate_data, predict_score,
data_type="validate")
"""
predict_result = DataFrame.vstack([predict_result, validate_predict_result])
train_output_data.write(predict_result)

@@ -300,8 +307,9 @@ def predict_guest(ctx, input_model, test_data, test_output_data):

module = CoordinatedLinRModuleGuest.from_model(model)
test_data = test_data.read()
predict_score = module.predict(sub_ctx, test_data)
predict_result = transform_to_predict_result(test_data, predict_score, data_type="predict")
predict_result = module.predict(sub_ctx, test_data)
predict_result = tools.add_dataset_type(predict_result, consts.TEST_SET)
# predict_result = transform_to_predict_result(test_data, predict_score, data_type="predict")
test_output_data.write(predict_result)


@@ -313,7 +321,7 @@ def predict_host(ctx, input_model, test_data, test_output_data):
module.predict(sub_ctx, test_data)


def transform_to_predict_result(test_data, predict_score, data_type="test"):
"""def transform_to_predict_result(test_data, predict_score, data_type="test"):
df = test_data.create_frame(with_label=True, with_weight=False)
pred_res = test_data.create_frame(with_label=False, with_weight=False)
pred_res["predict_result"] = predict_score
@@ -322,4 +330,4 @@ def transform_to_predict_result(test_data, predict_score, data_type="test"):
v[0],
json.dumps({"label": v[0]}),
data_type], enable_type_align_checking=False)
return df
return df"""
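
Note on the pattern above: the component no longer assembles prediction frames itself; module.predict() now returns a complete prediction DataFrame, and the component only tags which data split it came from before writing it out. A minimal sketch of that tagging step, assuming tools.add_dataset_type simply writes the split name into a "type" column (the helper's real implementation is not part of this diff):

# Hypothetical sketch of tools.add_dataset_type; the real FATE helper may differ.
def add_dataset_type(predict_df, dataset_type):
    # tag every row with its split, e.g. consts.TRAIN_SET / VALIDATE_SET / TEST_SET
    predict_df["type"] = dataset_type
    return predict_df

# usage mirroring the component code above:
# train_predict_result = add_dataset_type(module.predict(sub_ctx, train_data), consts.TRAIN_SET)
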
41 changes: 23 additions & 18 deletions python/fate/components/components/coordinated_lr.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging

from fate.arch import Context
from fate.arch.dataframe import DataFrame
from fate.components.components.utils import consts, tools
from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
from fate.ml.glm import CoordinatedLRModuleGuest, CoordinatedLRModuleHost, CoordinatedLRModuleArbiter

@@ -217,17 +217,19 @@ def cross_validation(
module.fit(fold_ctx, train_data, validate_data)
if output_cv_data:
sub_ctx = fold_ctx.sub_ctx("predict_train")
predict_score = module.predict(sub_ctx, train_data)
train_predict_result = transform_to_predict_result(
predict_df = module.predict(sub_ctx, train_data)
"""train_predict_result = transform_to_predict_result(
train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr,
data_type="train"
)
)"""
train_predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET)
sub_ctx = fold_ctx.sub_ctx("predict_validate")
predict_score = module.predict(sub_ctx, validate_data)
validate_predict_result = transform_to_predict_result(
predict_df = module.predict(sub_ctx, validate_data)
"""validate_predict_result = transform_to_predict_result(
validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr,
data_type="predict"
)
)"""
validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET)
predict_result = DataFrame.vstack([train_predict_result, validate_predict_result])
next(cv_output_datas).write(df=predict_result)

@@ -294,21 +296,23 @@ def train_guest(

sub_ctx = ctx.sub_ctx("predict")

predict_score = module.predict(sub_ctx, train_data)
predict_result = transform_to_predict_result(
predict_df = module.predict(sub_ctx, train_data)
"""predict_result = transform_to_predict_result(
train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="train"
)
)"""
predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET)
if validate_data is not None:
sub_ctx = ctx.sub_ctx("validate_predict")
predict_score = module.predict(sub_ctx, validate_data)
validate_predict_result = transform_to_predict_result(
predict_df = module.predict(sub_ctx, validate_data)
"""validate_predict_result = transform_to_predict_result(
validate_data,
predict_score,
module.labels,
threshold=module.threshold,
is_ovr=module.ovr,
data_type="validate",
)
)"""
validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET)
predict_result = DataFrame.vstack([predict_result, validate_predict_result])
train_output_data.write(predict_result)

@@ -390,10 +394,11 @@ def predict_guest(ctx, input_model, test_data, test_output_data):
# if module.threshold != 0.5:
# module.threshold = threshold
test_data = test_data.read()
predict_score = module.predict(sub_ctx, test_data)
predict_result = transform_to_predict_result(
predict_df = module.predict(sub_ctx, test_data)
"""predict_result = transform_to_predict_result(
test_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="test"
)
)"""
predict_result = tools.add_dataset_type(predict_df, consts.TEST_SET)
test_output_data.write(predict_result)


@@ -406,7 +411,7 @@ def predict_host(ctx, input_model, test_data, test_output_data):
module.predict(sub_ctx, test_data)


def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5, is_ovr=False, data_type="test"):
"""def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5, is_ovr=False, data_type="test"):
if is_ovr:
df = test_data.create_frame(with_label=True, with_weight=False)
df[["predict_result", "predict_score", "predict_detail", "type"]] = predict_score.apply_row(
@@ -424,4 +429,4 @@ def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5,
lambda v: [int(v[0] > threshold), v[0], json.dumps({1: v[0], 0: 1 - v[0]}), data_type],
enable_type_align_checking=False,
)
return df
return df"""
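
The old transform_to_predict_result, kept above as a docstring, shows the output schema the component used to build: predict_result, predict_score, predict_detail and type columns. The new predict_tools path is assumed to yield the same columns. A worked binary-case example of one row, mirroring the old lambda (score 0.73, threshold 0.5):

# Illustrative only: one binary prediction row, following the old lambda's column layout.
import json

score, threshold = 0.73, 0.5
row = {
    "predict_result": int(score > threshold),                 # 1
    "predict_score": score,                                    # 0.73
    "predict_detail": json.dumps({1: score, 0: 1 - score}),    # per-class scores serialized as JSON
    "type": "train",                                           # split tag added by tools.add_dataset_type
}
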
6 changes: 5 additions & 1 deletion python/fate/ml/glm/hetero/coordinated_linr/guest.py
@@ -19,6 +19,7 @@

from fate.arch import dataframe, Context
from fate.ml.abc.module import HeteroModule
from fate.ml.utils import predict_tools
from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param
from fate.ml.utils._optimizer import Optimizer, LRScheduler

@@ -192,7 +193,10 @@ def predict(self, ctx, test_data):
pred = torch.matmul(X, self.w)
for h_pred in ctx.hosts.get("h_pred"):
pred += h_pred
return pred
pred_df = test_data.create_frame(with_label=True, with_weight=False)
pred_df[predict_tools.PREDICT_SCORE] = pred
predict_result = predict_tools.compute_predict_details(pred_df, task_type=predict_tools.REGRESSION)
return predict_result

def get_model(self):
"""w = self.w.tolist()
29 changes: 23 additions & 6 deletions python/fate/ml/glm/hetero/coordinated_lr/guest.py
@@ -19,6 +19,7 @@

from fate.arch import Context, dataframe
from fate.ml.abc.module import HeteroModule
from fate.ml.utils import predict_tools
from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param
from fate.ml.utils._optimizer import LRScheduler, Optimizer

@@ -104,8 +105,12 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
single_estimator = self.estimator[i]
single_estimator.epochs = self.epochs
single_estimator.batch_size = self.batch_size
train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]]
single_estimator.fit_single_model(class_ctx, train_data, validate_data)
class_train_data = train_data.copy()
class_validate_data = validate_data
if validate_data:
class_validate_data = validate_data.copy()
class_train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]]
single_estimator.fit_single_model(class_ctx, class_train_data, class_validate_data)
self.estimator[i] = single_estimator

else:
@@ -136,14 +141,26 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:

def predict(self, ctx, test_data):
if self.ovr:
predict_score = test_data.create_frame(with_label=False, with_weight=False)
pred_score = test_data.create_frame(with_label=False, with_weight=False)
for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)):
estimator = self.estimator[i]
pred = estimator.predict(class_ctx, test_data)
predict_score[self.labels[i]] = pred
pred_score[self.labels[i]] = pred
pred_df = test_data.create_frame(with_label=True, with_weight=False)
pred_df[predict_tools.PREDICT_SCORE] = pred_score.apply_row(lambda v: [list(v)])
predict_result = predict_tools.compute_predict_details(pred_df,
task_type=predict_tools.MULTI,
classes=self.labels)
else:
predict_score = self.estimator.predict(ctx, test_data)
return predict_score
pred_df = test_data.create_frame(with_label=True, with_weight=False)
pred_df[predict_tools.PREDICT_SCORE] = predict_score
predict_result = predict_tools.compute_predict_details(pred_df,
task_type=predict_tools.BINARY,
classes=self.labels,
threshold=self.threshold)

return predict_result

def get_model(self):
all_estimator = {}
@@ -292,7 +309,7 @@ def predict(self, ctx, test_data):
if self.init_param.get("fit_intercept"):
test_data["intercept"] = 1.0
X = test_data.values.as_tensor()
logger.info(f"in predict, w: {self.w}")
# logger.info(f"in predict, w: {self.w}")
pred = torch.matmul(X, self.w)
for h_pred in ctx.hosts.get("h_pred"):
pred += h_pred
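
In the OvR branch above, each per-class estimator's score is collected into one list per row (pred_score.apply_row(lambda v: [list(v)])) and handed to compute_predict_details with task_type=MULTI. Per row, that detail computation presumably amounts to an argmax over the per-class scores; a hypothetical sketch (the real logic lives in fate.ml.utils.predict_tools and is not shown in this diff):

# Hypothetical per-row multiclass (OvR) detail computation; illustrative only.
import json

def ovr_row_details(scores, classes):
    # scores: one one-vs-rest score per class, aligned with `classes`
    best = max(range(len(scores)), key=lambda i: scores[i])
    return {
        "predict_result": classes[best],
        "predict_score": scores[best],
        "predict_detail": json.dumps({str(c): s for c, s in zip(classes, scores)}),
    }

# e.g. ovr_row_details([0.2, 0.7, 0.1], classes=[0, 1, 2]) -> predict_result == 1, predict_score == 0.7
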
2 changes: 1 addition & 1 deletion python/fate/ml/glm/hetero/coordinated_lr/host.py
@@ -61,7 +61,7 @@ def set_epochs(self, epochs):
def fit(self, ctx: Context, train_data, validate_data=None) -> None:
encryptor = ctx.arbiter("encryptor").get()
label_count = ctx.guest("label_count").get()
if self.label_count > 2 or self.ovr:
if label_count > 2 or self.ovr:
self.ovr = True
self.label_count = label_count
warm_start = True
2 changes: 1 addition & 1 deletion python/fate/ml/utils/_model_param.py
@@ -41,8 +41,8 @@ def serialize_param(param, fit_intercept=False):
w = param.tolist()
intercept = None
if fit_intercept:
w = w[:-1]
intercept = w[-1]
w = w[:-1]
return {"coef_": w, "intercept_": intercept, "dtype": dtype}


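
The one-line move in serialize_param above matters because the intercept has to be read from the full parameter vector before the last entry is sliced off; with the old order, the "intercept" ended up being the last real coefficient. Illustrative example using plain lists instead of a tensor:

# Why the statement order matters (illustrative only).
param = [0.5, -1.2, 3.0]      # last entry is the intercept when fit_intercept=True

# old order (buggy): slice first, then read the intercept from the shortened list
w = param[:-1]                # [0.5, -1.2]
intercept = w[-1]             # -1.2  -- actually the last coefficient, not the intercept

# new order (fixed): read the intercept first, then drop it from the coefficients
intercept = param[-1]         # 3.0
w = param[:-1]                # [0.5, -1.2]
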
