From 79ce2c4a993b7eb708f1045b2e8e0cb858d7a76b Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 23 Nov 2022 01:48:09 -0800 Subject: [PATCH] feat(ml): lr support predict Signed-off-by: weiwee --- python/fate/arch/context/_io.py | 60 ++++++++++++------------- python/fate/components/components/lr.py | 32 ++++++------- python/fate/components/cpn.py | 24 +++++----- python/fate/ml/lr/guest.py | 17 +++++-- python/fate/ml/lr/host.py | 29 +++++++++++- schemas/tasks/lr.predict.guest.yaml | 55 +++++++++++++++++++++++ schemas/tasks/lr.predict.host.yaml | 53 ++++++++++++++++++++++ schemas/tasks/lr.predict.yaml | 56 ----------------------- 8 files changed, 206 insertions(+), 120 deletions(-) create mode 100644 schemas/tasks/lr.predict.guest.yaml create mode 100644 schemas/tasks/lr.predict.host.yaml delete mode 100644 schemas/tasks/lr.predict.yaml diff --git a/python/fate/arch/context/_io.py b/python/fate/arch/context/_io.py index 07b1ac8499..2b556198f2 100644 --- a/python/fate/arch/context/_io.py +++ b/python/fate/arch/context/_io.py @@ -1,39 +1,8 @@ -from typing import Protocol, overload +from typing import Protocol from ..unify import URI -class Reader: - @overload - def __init__(self, ctx, uri: str, **kwargs): - ... - - @overload - def __init__(self, ctx, data, **kwargs): - ... - - def __init__(self, ctx, *args, **kwargs): - self.ctx = ctx - if isinstance(args[0], str): - self.uri = args[0] - self.name = kwargs.get("name", "") - self.metadata = kwargs.get("metadata", {}) - elif hasattr(args[0], "uri"): - self.uri = args[0].uri - self.name = args[0].name - self.metadata = args[0].metadata - else: - raise ValueError(f"invalid arguments: {args} and {kwargs}") - - def read_dataframe(self): - from fate.arch import dataframe - - self.data = dataframe.CSVReader( - id_name="id", label_name="y", label_type="float32", delimiter=",", dtype="float32" - ).to_frame(self.ctx, self.uri) - return self - - class IOKit: @staticmethod def _parse_args(arg, **kwargs): @@ -81,6 +50,11 @@ def get_reader(format, ctx, name, uri, metadata) -> Reader: if format == "csv": return CSVReader(ctx, name, uri.path, metadata) + if format == "json": + return JsonReader(ctx, name, uri.path, metadata) + + raise ValueError(f"reader for format {format} not found") + class CSVReader: def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: @@ -105,6 +79,26 @@ def read_dataframe(self): return DataframeReader(dataframe_reader, self.metadata["num_features"], self.metadata["num_samples"]) +class JsonReader: + def __init__(self, ctx, name: str, uri, metadata: dict) -> None: + self.name = name + self.ctx = ctx + self.uri = uri + self.metadata = metadata + + def read_model(self): + import json + + with open(self.uri, "r") as f: + return json.load(f) + + def read_metric(self): + import json + + with open(self.uri, "r") as f: + return json.load(f) + + class DataframeReader: def __init__(self, frames, num_features, num_samples) -> None: self.data = frames @@ -133,6 +127,8 @@ def get_writer(format, ctx, name, uri, metadata) -> Reader: if format == "json": return JsonWriter(ctx, name, uri.path, metadata) + raise ValueError(f"wirter for format {format} not found") + class Writer(Protocol): ... diff --git a/python/fate/components/components/lr.py b/python/fate/components/components/lr.py index 333613963f..e9570e5ed9 100644 --- a/python/fate/components/components/lr.py +++ b/python/fate/components/components/lr.py @@ -26,15 +26,15 @@ @cpn.component(roles=roles.get_all(), provider="fate", version="2.0.0.alpha") @cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST], stages=[stages.TRAIN]) @cpn.artifact( - "validate_data", type=Input[DatasetArtifact], optional=True, roles=[roles.GUEST, roles.HOST], stages=["train"] + "validate_data", type=Input[DatasetArtifact], optional=True, roles=[roles.GUEST, roles.HOST], stages=[stages.TRAIN] ) @cpn.artifact("input_model", type=Input[ModelArtifact], roles=[roles.GUEST, roles.HOST], stages=[stages.PREDICT]) @cpn.artifact( - "test_data", type=Input[DatasetArtifacts], optional=False, roles=[roles.GUEST, roles.HOST], stages=[stages.PREDICT] + "test_data", type=Input[DatasetArtifact], optional=False, roles=[roles.GUEST, roles.HOST], stages=[stages.PREDICT] ) -@cpn.parameter("learning_rate", type=float, default=0.1, optional=False) -@cpn.parameter("max_iter", type=int, default=100, optional=False) -@cpn.parameter("batch_size", type=int, default=100, optional=False) +@cpn.parameter("learning_rate", type=float, default=0.1) +@cpn.parameter("max_iter", type=int, default=100) +@cpn.parameter("batch_size", type=int, default=100) @cpn.artifact( "train_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST], stages=[stages.TRAIN] ) @@ -122,18 +122,20 @@ def train_arbiter(ctx, max_iter, train_output_metric): def predict_guest(ctx, input_model, test_data, test_output_data): from fate.ml.lr.guest import LrModuleGuest - model = ctx.read(input_model).load_model() - module = LrModuleGuest.from_model(model) - test_data = ctx.read(test_data).load_dataframe() - output_data = module.predict(ctx, test_data) - ctx.write(test_output_data).save_dataframe(output_data) + with ctx.sub_ctx("predict") as sub_ctx: + model = sub_ctx.reader(input_model).read_model() + module = LrModuleGuest.from_model(model) + test_data = sub_ctx.reader(test_data).read_dataframe() + output_data = module.predict(sub_ctx, test_data) + sub_ctx.writer(test_output_data).write_dataframe(output_data) def predict_host(ctx, input_model, test_data, test_output_data): from fate.ml.lr.host import LrModuleHost - model = ctx.read(input_model).load_model() - module = LrModuleHost.from_model(model) - test_data = ctx.read(test_data).load_dataframe() - output_data = module.predict(ctx, test_data) - ctx.write(test_output_data).save_dataframe(output_data) + with ctx.sub_ctx("predict") as sub_ctx: + model = sub_ctx.reader(input_model).read_model() + module = LrModuleHost.from_model(model) + test_data = sub_ctx.reader(test_data).read_dataframe() + output_data = module.predict(sub_ctx, test_data) + sub_ctx.writer(test_output_data).write_dataframe(output_data) diff --git a/python/fate/components/cpn.py b/python/fate/components/cpn.py index ea56a44af5..26c03ad3d2 100644 --- a/python/fate/components/cpn.py +++ b/python/fate/components/cpn.py @@ -163,20 +163,20 @@ def validate_and_extract_execute_args(self, config): # arg support to be parameter elif parameter := name_parameter_mapping.get(arg): - if not parameter.optional: - parameter_apply = config.inputs.parameters.get(arg) - if parameter_apply is None: + parameter_apply = config.inputs.parameters.get(arg) + if parameter_apply is None: + if not parameter.optional: raise ComponentApplyError(f"parameter `{arg}` required, declare: `{parameter}`") else: - if type(parameter_apply) != parameter.type: - raise ComponentApplyError( - f"parameter `{arg}` with applying config `{parameter_apply}` can't apply to `{parameter}`" - f": {type(parameter_apply)} != {parameter.type}" - ) - else: - execute_args.append(parameter_apply) + execute_args.append(parameter_apply) else: - execute_args.append(parameter.default) + if type(parameter_apply) != parameter.type: + raise ComponentApplyError( + f"parameter `{arg}` with applying config `{parameter_apply}` can't apply to `{parameter}`" + f": {type(parameter_apply)} != {parameter.type}" + ) + else: + execute_args.append(parameter_apply) else: raise ComponentApplyError(f"should not go here") @@ -365,7 +365,7 @@ def __str__(self) -> str: return f"Parameter" -def parameter(name, type, default=None, optional=False, desc=None): +def parameter(name, type, default=None, optional=True, desc=None): """attaches an parameter to the component.""" def decorator(f): diff --git a/python/fate/ml/lr/guest.py b/python/fate/ml/lr/guest.py index c170ce2639..609ba3861c 100644 --- a/python/fate/ml/lr/guest.py +++ b/python/fate/ml/lr/guest.py @@ -69,7 +69,6 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: self.w = w def predict(self, ctx, test_data): - logger.info(test_data) batch_loader = dataframe.DataLoader( test_data, ctx=ctx, @@ -83,8 +82,20 @@ def predict(self, ctx, test_data): print(output) def get_model(self): - return {"w": self.w.to_local()._storage.data.tolist()} + return { + "w": self.w.to_local()._storage.data.tolist(), + "metadata": { + "max_iter": self.max_iter, + "batch_size": self.batch_size, + "learning_rate": self.learning_rate, + "alpha": self.alpha, + }, + } @classmethod def from_model(cls, model) -> "LrModuleGuest": - ... + lr = LrModuleGuest(**model["metadata"]) + import torch + + lr.w = tensor.tensor(torch.tensor(model["w"])) + return lr diff --git a/python/fate/ml/lr/host.py b/python/fate/ml/lr/host.py index a0bd3f7b09..cb524f33a6 100644 --- a/python/fate/ml/lr/host.py +++ b/python/fate/ml/lr/host.py @@ -51,8 +51,33 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: self.w = w def get_model(self): - return {"w": self.w.to_local()._storage.data.tolist()} + return { + "w": self.w.to_local()._storage.data.tolist(), + "metadata": { + "max_iter": self.max_iter, + "batch_size": self.batch_size, + "learning_rate": self.learning_rate, + "alpha": self.alpha, + }, + } + + def predict(self, ctx, test_data): + batch_loader = DataLoader( + test_data, + ctx=ctx, + batch_size=-1, + mode="hetero", + role="host", + sync_arbiter=False, + ) + for X in batch_loader: + output = tensor.matmul(X, self.w) + print(output) @classmethod def from_model(cls, model) -> "LrModuleHost": - ... + lr = LrModuleHost(**model["metadata"]) + import torch + + lr.w = tensor.tensor(torch.tensor(model["w"])) + return lr diff --git a/schemas/tasks/lr.predict.guest.yaml b/schemas/tasks/lr.predict.guest.yaml new file mode 100644 index 0000000000..310ebb92f4 --- /dev/null +++ b/schemas/tasks/lr.predict.guest.yaml @@ -0,0 +1,55 @@ +execution_id: predict_xxx +component: hetero_lr +role: guest +stage: predict +inputs: + artifacts: + test_data: + name: predict_data + uri: file:///Users/sage/proj/FATE/2.0.0-alpha/examples/data/breast_hetero_guest.csv + metadata: + format: csv + id_name: id + delimiter: "," + label_name: y + label_type: float32 + dtype: float32 + num_features: 10 + num_samples: 569 + input_model: + name: input_model + uri: file:///Users/sage/alpha/model/guest.json + metadata: + format: json +outputs: + artifacts: + test_output_data: + name: test_output_data + uri: file:///Users/sage/alpha/data/guest_predict_output.csv + metadata: + format: csv +conf: + mlmd: + type: pipeline + metadata: + db: /Users/sage/alpha/mlmd.db + logger: + type: pipeline + metadata: + basepath: /Users/sage/alpha/logs + level: DEBUG + debug_mode: true + device: CPU + computing: + engine: standalone + computing_id: xxxx + federation: + engine: standalone + federation_id: xxx + parties: + local: + role: guest + partyid: "9999" + parties: + - role: host + partyid: "10000" diff --git a/schemas/tasks/lr.predict.host.yaml b/schemas/tasks/lr.predict.host.yaml new file mode 100644 index 0000000000..a36818a5e0 --- /dev/null +++ b/schemas/tasks/lr.predict.host.yaml @@ -0,0 +1,53 @@ +execution_id: predict_xxx +component: hetero_lr +role: host +stage: predict +inputs: + artifacts: + test_data: + name: predict_data + uri: file:///Users/sage/proj/FATE/2.0.0-alpha/examples/data/breast_hetero_host.csv + metadata: + format: csv + id_name: id + delimiter: "," + dtype: float32 + num_features: 20 + num_samples: 569 + input_model: + name: input_model + uri: file:///Users/sage/alpha/model/host.json + metadata: + format: json +outputs: + artifacts: + test_output_data: + name: test_output_data + uri: file:///Users/sage/alpha/data/host_predict_output.csv + metadata: + format: csv +conf: + mlmd: + type: pipeline + metadata: + db: /Users/sage/alpha/mlmd.db + logger: + type: pipeline + metadata: + basepath: /Users/sage/alpha/logs + level: DEBUG + debug_mode: true + device: CPU + computing: + engine: standalone + computing_id: xxxx + federation: + engine: standalone + federation_id: xxx + parties: + local: + role: host + partyid: "10000" + parties: + - role: guest + partyid: "9999" diff --git a/schemas/tasks/lr.predict.yaml b/schemas/tasks/lr.predict.yaml deleted file mode 100644 index abe5478bdc..0000000000 --- a/schemas/tasks/lr.predict.yaml +++ /dev/null @@ -1,56 +0,0 @@ -execution_id: xxx -component: hetero_lr -role: guest -stage: predict -inputs: - parameters: - learning_rate: 0.2 - max_iter: 100 - artifacts: - test_data: - - name: test_data1 - uri: file:///tmp/test_data1 - metadata: - format: csv - - name: test_data2 - uri: file:///tmp/test_data2 - metadata: - format: csv - input_model: - name: input_model - uri: file:///tmp/trained_model -outputs: - artifacts: - test_output_data: - name: train_output_data - uri: file:///tmp/train_data_forward - metadata: - format: csv -env: - mlmd: - type: pipeline - metadata: - state_path: /Users/sage/fate_tmp/state - terminate_state_path: /Users/sage/fate_tmp/terminate_state - logger: - type: pipeline - metadata: - basepath: /Users/sage/fate_tmp/logs - level: DEBUG - debug_mode: true - device: CPU - distributed_computing_backend: - engine: standalone - computing_id: xxxx - federation_backend: - engine: standalone - federation_id: xxx - parties: - local: - role: guest - partyid: "9999" - parties: - - role: host - partyid: "10000" - - role: arbiter - partyid: "10001"