Skip to content

Commit

Permalink
NN adaption
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Jun 28, 2023
1 parent dbdf0f9 commit 2a42ad3
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 68 deletions.
1 change: 1 addition & 0 deletions python/fate/arch/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
)
)

LOGGER.error(f"DATA_PATH: {_data_dir}")

# noinspection PyPep8Naming
class Table(object):
Expand Down
82 changes: 36 additions & 46 deletions python/fate/components/components/homo_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
logger = logging.getLogger(__name__)


FATE_TEST_PATH = '/home/cwj/FATE/playground/test_output_path'


def is_path(s):
    """Return True when *s* refers to an existing entry on the filesystem."""
    found = os.path.exists(s)
    return found

Expand Down Expand Up @@ -63,7 +60,7 @@ def prepare_context_and_role(runner, ctx, role, sub_ctx_name):
return sub_ctx


def get_input_data(stage, cpn_input_data, input_type='df'):
def get_input_data(stage, cpn_input_data, save_path, input_type='df',):
if stage == 'train':
train_data, validate_data = cpn_input_data
if input_type == 'df':
Expand Down Expand Up @@ -134,7 +131,7 @@ def handle_nn_output(ctx, nn_output: NNOutput, output_class, stage):
raise ValueError('test result not found in the NNOutput: {}'.format(nn_output))
write_output_df(ctx, nn_output.test_result, output_class, nn_output.match_id_name, nn_output.sample_id_name)
else:
logger.warning('train output is not NNOutput, but {}'.format(type(nn_output)))
logger.warning('train output is not NNOutput, but {}, fail to output dataframe'.format(type(nn_output)))


@cpn.component(roles=[GUEST, HOST, ARBITER])
Expand All @@ -152,63 +149,56 @@ def train(
runner_class: cpn.parameter(type=str, default='DefaultRunner', desc="class name of your runner class"),
runner_conf: cpn.parameter(type=dict, default={}, desc="the parameter dict of the NN runner class"),
source: cpn.parameter(type=str, default=None, desc="path to your runner script folder"),
dataframe_output: cpn.dataframe_output(roles=[GUEST, HOST]),
json_metric_output: cpn.json_metric_output(roles=[GUEST, HOST]),
model_directory_output: cpn.model_directory_output(roles=[GUEST, HOST]),
data_output: cpn.dataframe_output(roles=[GUEST, HOST]),
metric_output: cpn.json_metric_output(roles=[GUEST, HOST]),
model_output: cpn.model_directory_output(roles=[GUEST, HOST]),
):

runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
sub_ctx = prepare_context_and_role(runner, ctx, role, consts.TRAIN)

if role.is_guest or role.is_host: # is client

input_data = get_input_data(consts.TRAIN, [train_data, validate_data])
input_data.fate_save_path = FATE_TEST_PATH
output_path = model_output.get_directory()
input_data = get_input_data(consts.TRAIN, [train_data, validate_data], output_path)
ret: NNOutput = runner.train(input_data=input_data)
logger.info("train result: {}".format(ret))
handle_nn_output(sub_ctx, ret, dataframe_output, consts.TRAIN)

handle_nn_output(sub_ctx, ret, data_output, consts.TRAIN)
output_conf = model_output(runner_module,
runner_class,
runner_conf,
source,
FATE_TEST_PATH)
output_path = model_directory_output.get_directory()
output_path)
logger.info("output_path: {}".format(output_conf))
model_directory_output.write_metadata(output_conf)
json_metric_output.write({"train":1123})
model_output.write_metadata(output_conf)
metric_output.write({"nn_conf": output_conf})

elif role.is_arbiter: # is server
runner.train()


# @homo_nn.predict()
# @cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST])
# @cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST])
# @cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
# def predict(
# ctx,
# role: Role,
# test_data,
# input_model,
# test_output_data,
# ):

# if role.is_guest or role.is_host: # is client

# import json
# path = '/home/cwj/FATE/playground/test_output_model/'
# model_conf = json.load(open(path + str(role.name) + '_conf.json', 'r'))
# runner_module = model_conf['runner_module']
# runner_class = model_conf['runner_class']
# runner_conf = model_conf['runner_conf']
# source = model_conf['source']

# runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
# sub_ctx = prepare_context_and_role(runner, ctx, role, consts.PREDICT)
# input_data = get_input_data(consts.PREDICT, test_data)
# pred_rs = runner.predict(input_data)
# handle_nn_output(sub_ctx, pred_rs, test_output_data, consts.PREDICT)

# elif role.is_arbiter: # is server
# logger.info('arbiter skip predict')
@homo_nn.predict()
def predict(
    ctx,
    role: Role,
    test_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True),
    model_input: cpn.model_directory_input(roles=[GUEST, HOST]),
    data_output: cpn.dataframe_output(roles=[GUEST, HOST])
):
    """Predict entry point of the homo NN component.

    Clients (guest/host) rebuild the runner from the metadata persisted in the
    trained model directory, run prediction on ``test_data`` and hand the
    result to ``handle_nn_output`` for dataframe output; the arbiter (server)
    side does nothing at predict time.
    """
    if role.is_guest or role.is_host:  # is client

        # Runner construction parameters were saved as model-directory
        # metadata at train time (write_metadata of the train component).
        model_conf = model_input.get_metadata()
        runner_module = model_conf['runner_module']
        runner_class = model_conf['runner_class']
        runner_conf = model_conf['runner_conf']
        source = model_conf['source']

        runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
        sub_ctx = prepare_context_and_role(runner, ctx, role, consts.PREDICT)
        # NOTE(review): get_input_data's updated signature takes a required
        # ``save_path`` positional argument, but it is called here with only
        # two arguments — this looks like it would raise TypeError at predict
        # time; confirm and supply a save path if required.
        input_data = get_input_data(consts.PREDICT, test_data)
        ret: NNOutput = runner.predict(input_data)
        handle_nn_output(sub_ctx, ret, data_output, consts.PREDICT)

    elif role.is_arbiter:  # is server
        logger.info('arbiter skip predict')
26 changes: 11 additions & 15 deletions python/fate/components/components/nn/runner/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ def __init__(self,
self.tokenizer_conf = tokenizer_conf
self.task_type = task_type

self._resume_from_checkpoint = False

def _loader_load_from_conf(self, conf, return_class=False):
if conf is None:
return None
Expand Down Expand Up @@ -181,26 +179,23 @@ def setup(self, cpn_input_data: NNInput, stage='train'):

# load arguments, models, etc
# prepare dataset

# dataset
train_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_train_data())
validate_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_validate_data())
test_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_test_data())

# load model
model = self._loader_load_from_conf(self.model_conf)

# save path: path to save provided by fate framework
save_path = cpn_input_data.get_fate_save_path()
# if have input model for warm-start
model_path = cpn_input_data.get_model_path()

# resume_from checkpoint path
resume_path = None
if model_path is not None:
path = cpn_input_data.get_model_path()
model_dict = load_model_dict_from_path(path)
model_dict = load_model_dict_from_path(model_path)
model.load_state_dict(model_dict)
if get_last_checkpoint(path) is not None:
self._resume_from_checkpoint = True
else:
model_path = './'
if get_last_checkpoint(model_path) is not None:
resume_path = model_path

# load optimizer
optimizer_loader = Loader.from_dict(self.optimizer_conf)
Expand All @@ -211,11 +206,12 @@ def setup(self, cpn_input_data: NNInput, stage='train'):
loss = self._loader_load_from_conf(self.loss_conf)
# load collator func
data_collator = self._loader_load_from_conf(self.data_collator_conf)
#
# load tokenizer if import conf provided
tokenizer = self._loader_load_from_conf(self.tokenizer_conf)
# args
training_args = TrainingArguments(**self.training_args_conf)
training_args.output_dir = model_path # reset to default, saving to arbitrary path is not allowed in NN component
training_args.output_dir = save_path # reset to default, saving to arbitrary path is not allowed in NN component
training_args.resume_from_checkpoint = resume_path # resume path
fed_args = FedAVGArguments(**self.fed_args_conf)

# prepare trainer
Expand All @@ -239,7 +235,7 @@ def train(self, input_data: NNInput = None) -> Optional[Union[NNOutput, None]]:
trainer = setup['trainer']
if self.is_client():

trainer.train(resume_from_checkpoint=self._resume_from_checkpoint)
trainer.train()
trainer.save_model(input_data.get('fate_save_path'))
# predict the dataset when training is done
train_rs = trainer.predict(setup['train_set']) if setup['train_set'] else None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional, Type

from .._base_type import Role, _create_artifact_annotation
from ._directory import ModelDirectoryArtifactDescribe
from ._directory import ModelDirectoryArtifactDescribe, ModelDirectoryReader, ModelDirectoryWriter
from ._json import JsonModelArtifactDescribe, JsonModelReader, JsonModelWriter


Expand All @@ -23,23 +23,23 @@ def json_model_outputs(roles: Optional[List[Role]] = None, desc="", optional=Fal

def model_directory_input(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[ModelDirectoryArtifactDescribe]:
) -> Type[ModelDirectoryReader]:
return _create_artifact_annotation(True, False, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional)


def model_directory_inputs(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[List[ModelDirectoryArtifactDescribe]]:
) -> Type[List[ModelDirectoryReader]]:
return _create_artifact_annotation(True, True, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional)


def model_directory_output(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[ModelDirectoryArtifactDescribe]:
) -> Type[ModelDirectoryWriter]:
return _create_artifact_annotation(False, False, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional)


def model_directory_outputs(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[List[ModelDirectoryArtifactDescribe]]:
) -> Type[List[ModelDirectoryWriter]]:
return _create_artifact_annotation(False, True, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional)

0 comments on commit 2a42ad3

Please sign in to comment.