Skip to content

Commit

Permalink
Refactorize NN Components: Setup -> Runner
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Jun 13, 2023
1 parent c1e01da commit d7dbae3
Show file tree
Hide file tree
Showing 5 changed files with 398 additions and 397 deletions.
134 changes: 53 additions & 81 deletions python/fate/components/components/homo_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
)
import os
import pandas as pd
from typing import Literal
from fate.interface import Context
from fate.components.components.nn.setup.fate_setup import FateSetup
from fate.components.components.nn.nn_setup import NNRunner, SetupReturn, ComponentInputData
from fate.components.components.nn.runner.fate_runner import FateRunner
from fate.components.components.nn.nn_runner import NNRunner, NNInput, NNOutput
from fate.components.components.nn.loader import Loader
from fate.arch.dataframe._dataframe import DataFrame
import logging
Expand All @@ -39,25 +38,28 @@
logger = logging.getLogger(__name__)


FATE_TEST_PATH = '/home/cwj/FATE/playground/test_output_path'


def is_path(s):
return os.path.exists(s)


def prepare_setup_class(setup_module, setup_class, setup_conf, source):
print('setup conf is {}'.format(setup_conf))
def prepare_runner_class(runner_module, runner_class, runner_conf, source):
print('runner conf is {}'.format(runner_conf))
print('source is {}'.format(source))
if setup_module != 'fate_setup':
if runner_module != 'fate_runner':
if source == None:
# load from default folder
setup = Loader('fate.components.components.nn.setup.' + setup_module, setup_class, **setup_conf)()
runner = Loader('fate.components.components.nn.runner.' + runner_module, runner_class, **runner_conf)()
else:
setup = Loader(setup_module, setup_class, source=source, **setup_conf)()
assert isinstance(setup, NNRunner), 'loaded class must be a subclass of NNSetup class, but got {}'.format(type(setup))
runner = Loader(runner_module, runner_class, source=source, **runner_conf)()
assert isinstance(runner, NNRunner), 'loaded class must be a subclass of NNRunner class, but got {}'.format(type(runner))
else:
print('using default fate setup')
setup = FateSetup(**setup_conf)
print('using default fate runner')
runner = FateRunner(**runner_conf)

return setup
return runner


def transform_input_dataframe(sub_ctx, data):
Expand All @@ -67,60 +69,40 @@ def transform_input_dataframe(sub_ctx, data):
return df


def prepare_context_and_role(setup, ctx, role, sub_ctx_name):
def prepare_context_and_role(runner, ctx, role, sub_ctx_name):
with ctx.sub_ctx(sub_ctx_name) as sub_ctx:
# set context
setup.set_context(sub_ctx)
setup.set_role(role)
runner.set_context(sub_ctx)
runner.set_role(role)
return sub_ctx


def process_setup(setup, stage=Literal['train', 'predict']):
setup_ret = setup.setup(stage)
print('setup class is {}'.format(setup))
if not isinstance(setup_ret, SetupReturn):
raise ValueError(f'The return of your setup class must be a SetupReturn Instance, but got {setup_ret}')
return setup_ret


def handle_client(setup, sub_ctx, stage, cpn_input_data):
def get_input_data(sub_ctx, stage, cpn_input_data):
if stage == 'train':
train_data, validate_data = cpn_input_data
train_data = transform_input_dataframe(sub_ctx, train_data)
if validate_data is not None:
validate_data = transform_input_dataframe(sub_ctx, validate_data)
setup.set_cpn_input_data(ComponentInputData(train_data, validate_data))
return NNInput(train_data=train_data, validate_data=validate_data)
elif stage == 'predict':
test_data = cpn_input_data
test_data = transform_input_dataframe(sub_ctx, test_data)
setup.set_cpn_input_data(ComponentInputData(test_data=test_data))
return NNInput(test_data=test_data)
else:
raise ValueError(f'Unknown stage {stage}')
setup_ret = process_setup(setup, stage)
return setup_ret


def handle_server(setup, stage):
return process_setup(setup, stage)['trainer']


def update_output_dir(trainer):

FATE_TEST_PATH = '/home/cwj/FATE/playground/test_output_path'
# default trainer
trainer.args.output_dir = FATE_TEST_PATH


def model_output(setup_module,
setup_class,
setup_conf,
def model_output(runner_module,
runner_class,
runner_conf,
source,
model_output_path
):
return {
'setup_module': setup_module,
'setup_class': setup_class,
'setup_conf': setup_conf,
'runner_module': runner_module,
'runner_class': runner_class,
'runner_conf': runner_conf,
'source': source,
'model_output_path': model_output_path
}
Expand All @@ -134,10 +116,10 @@ def homo_nn(ctx, role):
@homo_nn.train()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST], desc="training data")
@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST], desc="validation data")
@cpn.parameter("setup_module", type=str, default='fate_setup', desc="name of your setup script")
@cpn.parameter("setup_class", type=str, default='FateSetup', desc="class name of your setup class")
@cpn.parameter("source", type=str, default=None, desc="path to your setup script folder")
@cpn.parameter("setup_conf", type=dict, default={}, desc="the parameter dict of the NN setup class")
@cpn.parameter("runner_module", type=str, default='fate_runner', desc="name of your runner script")
@cpn.parameter("runner_class", type=str, default='FateRunner', desc="class name of your runner class")
@cpn.parameter("source", type=str, default=None, desc="path to your runner script folder")
@cpn.parameter("runner_conf", type=dict, default={}, desc="the parameter dict of the NN runner class")
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
@cpn.artifact("train_output_metric", type=Output[LossMetrics], roles=[ARBITER])
@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST])
Expand All @@ -146,32 +128,28 @@ def train(
role: Role,
train_data,
validate_data,
setup_module,
setup_class,
setup_conf,
runner_module,
runner_class,
runner_conf,
source,
train_output_data,
train_output_metric,
output_model,
):

setup = prepare_setup_class(setup_module, setup_class, setup_conf, source)
sub_ctx = prepare_context_and_role(setup, ctx, role, "train")
runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
sub_ctx = prepare_context_and_role(runner, ctx, role, "train")

if role.is_guest or role.is_host: # is client
setup_ret = handle_client(setup, sub_ctx, 'train', [train_data, validate_data])
client_trainer = setup_ret['trainer']
update_output_dir(client_trainer) # update output dir
client_trainer.train()

output_conf = model_output(setup_module,
setup_class,
setup_conf,
source,
client_trainer.args.output_dir)
input_data = get_input_data(sub_ctx, 'train', [train_data, validate_data])
input_data.fate_save_path = FATE_TEST_PATH
ret = runner.train(input_data=input_data)

print('model output is {}'.format(output_conf))
client_trainer.save_model()
output_conf = model_output(runner_module,
runner_class,
runner_conf,
source,
FATE_TEST_PATH)
import json
path = '/home/cwj/FATE/playground/test_output_model/'
json.dump(output_conf, open(path + str(role.name) + '_conf.json', 'w'), indent=4)
Expand All @@ -180,8 +158,7 @@ def train(
model_writer.write_model("homo_nn", {}, metadata={})

elif role.is_arbiter: # is server
server_trainer = handle_server(setup, 'train')
server_trainer.train()
runner.train()


@homo_nn.predict()
Expand All @@ -201,21 +178,16 @@ def predict(
import json
path = '/home/cwj/FATE/playground/test_output_model/'
model_conf = json.load(open(path + str(role.name) + '_conf.json', 'r'))
setup_module = model_conf['setup_module']
setup_class = model_conf['setup_class']
setup_conf = model_conf['setup_conf']
runner_module = model_conf['runner_module']
runner_class = model_conf['runner_class']
runner_conf = model_conf['runner_conf']
source = model_conf['source']

setup = prepare_setup_class(setup_module, setup_class, setup_conf, source)
sub_ctx = prepare_context_and_role(setup, ctx, role, "predict")
setup_ret = handle_client(setup, sub_ctx, 'predict', test_data)
to_predict_dataset = setup_ret['test_set']
client_trainer = setup_ret['trainer']

if to_predict_dataset is None:
raise ValueError('The return of your setup class in the training stage must have "test_set" in the predict stage')

pred_rs = client_trainer.predict(to_predict_dataset)

runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
sub_ctx = prepare_context_and_role(runner, ctx, role, "predict")
input_data = get_input_data(sub_ctx, 'predict', test_data)
pred_rs = runner.predict(input_data)

print(f'predict result is {pred_rs}')

elif role.is_arbiter: # is server
Expand Down
99 changes: 99 additions & 0 deletions python/fate/components/components/nn/nn_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pandas as pd
from typing import Union, Type, Callable, Optional, List, Tuple
from fate.components import Role
from fate.interface import Context
from fate.ml.nn.trainer.trainer_base import FedTrainerClient, FedTrainerServer



class NNInput:
"""
Class to encapsulate input data for NN Runner.
Parameters:
train_data (Union[pd.DataFrame, str]): The training data as a pandas DataFrame or the file path to it.
validate_data (Union[pd.DataFrame, str]): The validation data as a pandas DataFrame or the file path to it.
test_data (Union[pd.DataFrame, str]): The testing data as a pandas DataFrame or the file path to it.
model_path (str): The path of a saved model.
fate_save_path (str): The path for you to save your trained model in current task.
"""

def __init__(self, train_data: Union[pd.DataFrame, str] = None,
validate_data: Union[pd.DataFrame, str] = None,
test_data: Union[pd.DataFrame, str] = None,
model_path: str = None,
fate_save_path: str = None
) -> None:
self.train_data = train_data
self.validate_data = validate_data
self.test_data = test_data
self.model_path = model_path
self.fate_save_path = fate_save_path

def get(self, key: str) -> Union[pd.DataFrame, str]:
return getattr(self, key)

def get_train_data(self) -> Union[pd.DataFrame, str]:
return self.train_data

def get_validate_data(self) -> Union[pd.DataFrame, str]:
return self.validate_data

def get_test_data(self) -> Union[pd.DataFrame, str]:
return self.test_data

def get_model_path(self) -> str:
return self.model_path

def get_fate_save_path(self) -> str:
return self.fate_save_path

def __repr__(self) -> str:
return f"NNInput(train_data={self.train_data}, validate_data={self.validate_data}, \
test_data={self.test_data}, model_path={self.model_path}, fate_save_path={self.fate_save_path})"


class NNOutput:

def __init__(self, data=None) -> None:
self.data = data


class NNRunner(object):

def __init__(self) -> None:

self._role = None
self._party_id = None
self._ctx: Context = None

def set_context(self, context: Context):
assert isinstance(context, Context)
self._ctx = context

def get_context(self) -> Context:
return self._ctx

def set_role(self, role: Role):
assert isinstance(role, Role)
self._role = role

def is_client(self) -> bool:
return self._role.is_guest or self._role.is_host

def is_server(self) -> bool:
return self._role.is_arbiter

def set_party_id(self, party_id: int):
assert isinstance(self._party_id, int)
self._party_id = party_id

def get_fateboard_tracker(self):
pass

def train(self, input_data: NNInput = None) -> Optional[Union[NNOutput, None]]:
pass

def predict(self, input_data: NNInput = None) -> Optional[Union[NNOutput, None]]:
pass

Loading

0 comments on commit d7dbae3

Please sign in to comment.