Skip to content

Commit

Permalink
Merge pull request #4958 from FederatedAI/dev-2.0.0-homo-nn-predict-fix
Browse files Browse the repository at this point in the history
Dev 2.0.0 homo nn predict fix
  • Loading branch information
mgqa34 authored Jun 30, 2023
2 parents cc23d9f + d490a21 commit 69a39b0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions python/fate/components/components/homo_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def handle_nn_output(ctx, nn_output: NNOutput, output_class, stage):
logger.warning("train output is not NNOutput, but {}, fail to output dataframe".format(type(nn_output)))


def warmstart_prepare(model_conf, runner_class, runner_module, runner_conf, source):
def prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source):

logger.info("loaded model_conf is: {}".format(model_conf))
if "saved_model_path" in model_conf:
Expand Down Expand Up @@ -196,7 +196,7 @@ def train(
saved_model_path=None
if train_model_input is not None:
model_conf = train_model_input.get_metadata()
runner_conf, source, runner_class, runner_module, saved_model_path = warmstart_prepare(model_conf, runner_class, runner_module, runner_conf, source)
runner_conf, source, runner_class, runner_module, saved_model_path = prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source)

output_path = train_model_output.get_directory()
input_data = get_input_data(consts.TRAIN, [train_data, validate_data], output_path, saved_model_path)
Expand Down Expand Up @@ -231,10 +231,11 @@ def predict(
runner_class = model_conf['runner_class']
runner_conf = model_conf['runner_conf']
source = model_conf['source']
saved_model_path = model_conf["saved_model_path"]

runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
sub_ctx = prepare_context_and_role(runner, ctx, role, consts.PREDICT)
input_data = get_input_data(consts.PREDICT, test_data)
input_data = get_input_data(consts.PREDICT, test_data, saved_model_path=saved_model_path)
ret: NNOutput = runner.predict(input_data)
handle_nn_output(sub_ctx, ret, predict_data_output, consts.PREDICT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def setup(self, cpn_input_data: NNInput, stage='train'):
model_path = cpn_input_data.get_saved_model_path()
# resume_from checkpoint path
resume_path = None

if model_path is not None:
model_dict = load_model_dict_from_path(model_path)
model.load_state_dict(model_dict)
Expand Down

0 comments on commit 69a39b0

Please sign in to comment.