Skip to content

Commit

Permalink
feat(ml): lr support predict
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Nov 23, 2022
1 parent a83be47 commit 79ce2c4
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 120 deletions.
60 changes: 28 additions & 32 deletions python/fate/arch/context/_io.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,8 @@
from typing import Protocol, overload
from typing import Protocol

from ..unify import URI


class Reader:
    """Reader bound to a context and a single data source.

    The source is either a URI string (optionally accompanied by ``name`` and
    ``metadata`` keyword arguments) or any object exposing ``uri``, ``name``
    and ``metadata`` attributes.
    """

    @overload
    def __init__(self, ctx, uri: str, **kwargs):
        ...

    @overload
    def __init__(self, ctx, data, **kwargs):
        ...

    def __init__(self, ctx, *args, **kwargs):
        """Store ``ctx`` and resolve ``uri`` / ``name`` / ``metadata``.

        Raises:
            ValueError: if no source is given, or the source is neither a
                string nor an object with a ``uri`` attribute.
        """
        self.ctx = ctx
        # A missing source must report the same ValueError as an unsupported
        # one instead of leaking an IndexError from ``args[0]``.
        source = args[0] if args else None
        if isinstance(source, str):
            self.uri = source
            self.name = kwargs.get("name", "")
            self.metadata = kwargs.get("metadata", {})
        elif hasattr(source, "uri"):
            self.uri = source.uri
            self.name = source.name
            self.metadata = source.metadata
        else:
            raise ValueError(f"invalid arguments: {args} and {kwargs}")

    def read_dataframe(self):
        """Load ``self.uri`` as a CSV dataframe into ``self.data``; return ``self``."""
        # Imported lazily, as in the original, to keep module import light.
        from fate.arch import dataframe

        csv_reader = dataframe.CSVReader(
            id_name="id", label_name="y", label_type="float32", delimiter=",", dtype="float32"
        )
        self.data = csv_reader.to_frame(self.ctx, self.uri)
        return self


class IOKit:
@staticmethod
def _parse_args(arg, **kwargs):
Expand Down Expand Up @@ -81,6 +50,11 @@ def get_reader(format, ctx, name, uri, metadata) -> Reader:
if format == "csv":
return CSVReader(ctx, name, uri.path, metadata)

if format == "json":
return JsonReader(ctx, name, uri.path, metadata)

raise ValueError(f"reader for format {format} not found")


class CSVReader:
def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None:
Expand All @@ -105,6 +79,26 @@ def read_dataframe(self):
return DataframeReader(dataframe_reader, self.metadata["num_features"], self.metadata["num_samples"])


class JsonReader:
    """Reader for artifacts stored as a JSON file on the local filesystem."""

    def __init__(self, ctx, name: str, uri, metadata: dict) -> None:
        """Keep the context, artifact name, file path (``uri``) and metadata."""
        self.name = name
        self.ctx = ctx
        self.uri = uri
        self.metadata = metadata

    def _load_json(self):
        """Parse the file at ``self.uri`` and return the decoded JSON value."""
        import json

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.uri, "r", encoding="utf-8") as f:
            return json.load(f)

    def read_model(self):
        """Return the model stored in the JSON file at ``self.uri``."""
        return self._load_json()

    def read_metric(self):
        """Return the metric stored in the JSON file at ``self.uri``."""
        return self._load_json()


class DataframeReader:
def __init__(self, frames, num_features, num_samples) -> None:
self.data = frames
Expand Down Expand Up @@ -133,6 +127,8 @@ def get_writer(format, ctx, name, uri, metadata) -> Reader:
if format == "json":
return JsonWriter(ctx, name, uri.path, metadata)

raise ValueError(f"writer for format {format} not found")


class Writer(Protocol):
...
Expand Down
32 changes: 17 additions & 15 deletions python/fate/components/components/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
@cpn.component(roles=roles.get_all(), provider="fate", version="2.0.0.alpha")
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST], stages=[stages.TRAIN])
@cpn.artifact(
"validate_data", type=Input[DatasetArtifact], optional=True, roles=[roles.GUEST, roles.HOST], stages=["train"]
"validate_data", type=Input[DatasetArtifact], optional=True, roles=[roles.GUEST, roles.HOST], stages=[stages.TRAIN]
)
@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[roles.GUEST, roles.HOST], stages=[stages.PREDICT])
@cpn.artifact(
"test_data", type=Input[DatasetArtifacts], optional=False, roles=[roles.GUEST, roles.HOST], stages=[stages.PREDICT]
"test_data", type=Input[DatasetArtifact], optional=False, roles=[roles.GUEST, roles.HOST], stages=[stages.PREDICT]
)
@cpn.parameter("learning_rate", type=float, default=0.1, optional=False)
@cpn.parameter("max_iter", type=int, default=100, optional=False)
@cpn.parameter("batch_size", type=int, default=100, optional=False)
@cpn.parameter("learning_rate", type=float, default=0.1)
@cpn.parameter("max_iter", type=int, default=100)
@cpn.parameter("batch_size", type=int, default=100)
@cpn.artifact(
"train_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST], stages=[stages.TRAIN]
)
Expand Down Expand Up @@ -122,18 +122,20 @@ def train_arbiter(ctx, max_iter, train_output_metric):
def predict_guest(ctx, input_model, test_data, test_output_data):
    """Run the guest-side predict stage inside a ``"predict"`` sub-context.

    Reads the model from ``input_model``, rebuilds the guest LR module from
    it, reads ``test_data`` as a dataframe, runs ``predict`` and writes the
    result to ``test_output_data``.
    """
    from fate.ml.lr.guest import LrModuleGuest

    # The pre-refactor calls (``ctx.read(...).load_model()`` and friends) are
    # dropped: they repeated the whole flow outside the sub-context and used
    # method names that the readers in this change (``read_model``,
    # ``read_dataframe``) do not provide.
    with ctx.sub_ctx("predict") as sub_ctx:
        model = sub_ctx.reader(input_model).read_model()
        module = LrModuleGuest.from_model(model)
        test_frame = sub_ctx.reader(test_data).read_dataframe()
        output_data = module.predict(sub_ctx, test_frame)
        sub_ctx.writer(test_output_data).write_dataframe(output_data)


def predict_host(ctx, input_model, test_data, test_output_data):
    """Run the host-side predict stage inside a ``"predict"`` sub-context.

    Reads the model from ``input_model``, rebuilds the host LR module from
    it, reads ``test_data`` as a dataframe, runs ``predict`` and writes the
    result to ``test_output_data``.
    """
    from fate.ml.lr.host import LrModuleHost

    # The pre-refactor calls (``ctx.read(...).load_model()`` and friends) are
    # dropped: they repeated the whole flow outside the sub-context and used
    # method names that the readers in this change (``read_model``,
    # ``read_dataframe``) do not provide.
    with ctx.sub_ctx("predict") as sub_ctx:
        model = sub_ctx.reader(input_model).read_model()
        module = LrModuleHost.from_model(model)
        test_frame = sub_ctx.reader(test_data).read_dataframe()
        output_data = module.predict(sub_ctx, test_frame)
        sub_ctx.writer(test_output_data).write_dataframe(output_data)
24 changes: 12 additions & 12 deletions python/fate/components/cpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,20 @@ def validate_and_extract_execute_args(self, config):

# arg support to be parameter
elif parameter := name_parameter_mapping.get(arg):
if not parameter.optional:
parameter_apply = config.inputs.parameters.get(arg)
if parameter_apply is None:
parameter_apply = config.inputs.parameters.get(arg)
if parameter_apply is None:
if not parameter.optional:
raise ComponentApplyError(f"parameter `{arg}` required, declare: `{parameter}`")
else:
if type(parameter_apply) != parameter.type:
raise ComponentApplyError(
f"parameter `{arg}` with applying config `{parameter_apply}` can't apply to `{parameter}`"
f": {type(parameter_apply)} != {parameter.type}"
)
else:
execute_args.append(parameter_apply)
execute_args.append(parameter_apply)
else:
execute_args.append(parameter.default)
if type(parameter_apply) != parameter.type:
raise ComponentApplyError(
f"parameter `{arg}` with applying config `{parameter_apply}` can't apply to `{parameter}`"
f": {type(parameter_apply)} != {parameter.type}"
)
else:
execute_args.append(parameter_apply)
else:
raise ComponentApplyError(f"should not go here")

Expand Down Expand Up @@ -365,7 +365,7 @@ def __str__(self) -> str:
return f"Parameter<name={self.name}, type={self.type}, default={self.default}, optional={self.optional}>"


def parameter(name, type, default=None, optional=False, desc=None):
def parameter(name, type, default=None, optional=True, desc=None):
"""attaches an parameter to the component."""

def decorator(f):
Expand Down
17 changes: 14 additions & 3 deletions python/fate/ml/lr/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
self.w = w

def predict(self, ctx, test_data):
logger.info(test_data)
batch_loader = dataframe.DataLoader(
test_data,
ctx=ctx,
Expand All @@ -83,8 +82,20 @@ def predict(self, ctx, test_data):
print(output)

def get_model(self):
    """Export the trained weights and hyper-parameters as a plain dict.

    Returns:
        A dict with ``"w"`` (the local weights converted via ``tolist()``)
        and ``"metadata"`` (``max_iter``, ``batch_size``, ``learning_rate``,
        ``alpha``). ``from_model`` reads ``model["metadata"]``, so the
        metadata must always be part of the export.
    """
    # A leftover early ``return {"w": ...}`` used to make the metadata below
    # unreachable; there is now a single return carrying both parts.
    weights = self.w.to_local()._storage.data.tolist()
    return {
        "w": weights,
        "metadata": {
            "max_iter": self.max_iter,
            "batch_size": self.batch_size,
            "learning_rate": self.learning_rate,
            "alpha": self.alpha,
        },
    }

@classmethod
def from_model(cls, model) -> "LrModuleGuest":
    """Rebuild a module from a dict shaped like the one ``get_model`` returns.

    Args:
        model: mapping with ``"metadata"`` (passed to the constructor as
            keyword arguments) and ``"w"`` (nested list of weights).

    Returns:
        A new instance whose ``w`` is restored from ``model["w"]``.
    """
    import torch

    # Build through ``cls`` so subclasses get their own type back; the stray
    # ``...`` placeholder statement is removed.
    lr = cls(**model["metadata"])
    lr.w = tensor.tensor(torch.tensor(model["w"]))
    return lr
29 changes: 27 additions & 2 deletions python/fate/ml/lr/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,33 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
self.w = w

def get_model(self):
    """Export the trained weights and hyper-parameters as a plain dict.

    Returns:
        A dict with ``"w"`` (the local weights converted via ``tolist()``)
        and ``"metadata"`` (``max_iter``, ``batch_size``, ``learning_rate``,
        ``alpha``). ``from_model`` reads ``model["metadata"]``, so the
        metadata must always be part of the export.
    """
    # A leftover early ``return {"w": ...}`` used to make the metadata below
    # unreachable; there is now a single return carrying both parts.
    weights = self.w.to_local()._storage.data.tolist()
    return {
        "w": weights,
        "metadata": {
            "max_iter": self.max_iter,
            "batch_size": self.batch_size,
            "learning_rate": self.learning_rate,
            "alpha": self.alpha,
        },
    }

def predict(self, ctx, test_data):
    """Print ``tensor.matmul(X, self.w)`` for every batch of ``test_data``.

    Batches come from a ``DataLoader`` configured for the host role in
    hetero mode without arbiter sync.
    """
    # NOTE(review): nothing is returned, yet ``predict_host`` passes this
    # method's result to ``write_dataframe`` -- confirm the intended output.
    # NOTE(review): ``batch_size=-1`` presumably means one full batch -- confirm.
    loader_options = dict(
        ctx=ctx,
        batch_size=-1,
        mode="hetero",
        role="host",
        sync_arbiter=False,
    )
    for features in DataLoader(test_data, **loader_options):
        scores = tensor.matmul(features, self.w)
        print(scores)

@classmethod
def from_model(cls, model) -> "LrModuleHost":
    """Rebuild a module from a dict shaped like the one ``get_model`` returns.

    Args:
        model: mapping with ``"metadata"`` (passed to the constructor as
            keyword arguments) and ``"w"`` (nested list of weights).

    Returns:
        A new instance whose ``w`` is restored from ``model["w"]``.
    """
    import torch

    # Build through ``cls`` so subclasses get their own type back; the stray
    # ``...`` placeholder statement is removed.
    lr = cls(**model["metadata"])
    lr.w = tensor.tensor(torch.tensor(model["w"]))
    return lr
55 changes: 55 additions & 0 deletions schemas/tasks/lr.predict.guest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
execution_id: predict_xxx
component: hetero_lr
role: guest
stage: predict
inputs:
artifacts:
test_data:
name: predict_data
uri: file:///Users/sage/proj/FATE/2.0.0-alpha/examples/data/breast_hetero_guest.csv
metadata:
format: csv
id_name: id
delimiter: ","
label_name: y
label_type: float32
dtype: float32
num_features: 10
num_samples: 569
input_model:
name: input_model
uri: file:///Users/sage/alpha/model/guest.json
metadata:
format: json
outputs:
artifacts:
test_output_data:
name: test_output_data
uri: file:///Users/sage/alpha/data/guest_predict_output.csv
metadata:
format: csv
conf:
mlmd:
type: pipeline
metadata:
db: /Users/sage/alpha/mlmd.db
logger:
type: pipeline
metadata:
basepath: /Users/sage/alpha/logs
level: DEBUG
debug_mode: true
device: CPU
computing:
engine: standalone
computing_id: xxxx
federation:
engine: standalone
federation_id: xxx
parties:
local:
role: guest
partyid: "9999"
parties:
- role: host
partyid: "10000"
53 changes: 53 additions & 0 deletions schemas/tasks/lr.predict.host.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
execution_id: predict_xxx
component: hetero_lr
role: host
stage: predict
inputs:
artifacts:
test_data:
name: predict_data
uri: file:///Users/sage/proj/FATE/2.0.0-alpha/examples/data/breast_hetero_host.csv
metadata:
format: csv
id_name: id
delimiter: ","
dtype: float32
num_features: 20
num_samples: 569
input_model:
name: input_model
uri: file:///Users/sage/alpha/model/host.json
metadata:
format: json
outputs:
artifacts:
test_output_data:
name: test_output_data
uri: file:///Users/sage/alpha/data/host_predict_output.csv
metadata:
format: csv
conf:
mlmd:
type: pipeline
metadata:
db: /Users/sage/alpha/mlmd.db
logger:
type: pipeline
metadata:
basepath: /Users/sage/alpha/logs
level: DEBUG
debug_mode: true
device: CPU
computing:
engine: standalone
computing_id: xxxx
federation:
engine: standalone
federation_id: xxx
parties:
local:
role: host
partyid: "10000"
parties:
- role: guest
partyid: "9999"
Loading

0 comments on commit 79ce2c4

Please sign in to comment.