Skip to content

Commit

Permalink
fix(ml): add model saver
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Nov 21, 2022
1 parent 126ab65 commit 9a8b6cd
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 52 deletions.
55 changes: 50 additions & 5 deletions python/fate/arch/context/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def read_dataframe(self):


class IOKit:
def reader(self, ctx, arg, **kwargs):
@staticmethod
def _parse_args(arg, **kwargs):
name = ""
metadata = {}
if hasattr(arg, "uri"):
Expand All @@ -54,20 +55,31 @@ def reader(self, ctx, arg, **kwargs):
if k not in ["name", "metadata"]:
metadata[k] = v

self.uri = URI.from_string(uri)
uri = URI.from_string(uri)
format = metadata.get("format")
return format, name, uri, metadata

def reader(self, ctx, arg, **kwargs) -> "Reader":
    """Resolve an input-artifact spec into a concrete Reader.

    Args:
        ctx: execution context passed through to the reader implementation.
        arg: artifact spec; anything `_parse_args` accepts (object with a
            `uri` attribute, or raw kwargs).
        **kwargs: extra spec fields merged into the artifact metadata.

    Raises:
        ValueError: if the spec's metadata carries no `format` key.
    """
    # `fmt` rather than `format` to avoid shadowing the builtin.
    fmt, name, uri, metadata = self._parse_args(arg, **kwargs)
    if fmt is None:
        raise ValueError(f"reader format `{fmt}` unknown")
    return get_reader(fmt, ctx, name, uri, metadata)

def writer(self, ctx, arg, **kwargs) -> "Writer":
    """Resolve an output-artifact spec into a concrete Writer.

    Args:
        ctx: execution context passed through to the writer implementation.
        arg: artifact spec; anything `_parse_args` accepts (object with a
            `uri` attribute, or raw kwargs).
        **kwargs: extra spec fields merged into the artifact metadata.

    Raises:
        ValueError: if the spec's metadata carries no `format` key.
    """
    fmt, name, uri, metadata = self._parse_args(arg, **kwargs)
    if fmt is None:
        # fixed copy-paste: this is the writer path, message said "reader"
        raise ValueError(f"writer format `{fmt}` unknown")
    return get_writer(fmt, ctx, name, uri, metadata)


class Reader(Protocol):
    """Structural interface for input-artifact readers.

    Any object exposing `read_dataframe` satisfies this protocol.
    """

    def read_dataframe(self):
        ...


def get_reader(format, ctx, name, uri, metadata) -> Reader:
    """Instantiate the Reader registered for *format*.

    Args:
        format: artifact format key (currently only "csv" is supported).
        ctx: execution context for the reader.
        name: artifact name.
        uri: parsed URI; only its filesystem `path` is handed to CSVReader.
        metadata: artifact metadata dict.

    Raises:
        ValueError: if no reader is registered for *format* — previously this
            silently returned None, deferring the failure to the caller.
    """
    if format == "csv":
        return CSVReader(ctx, name, uri.path, metadata)
    raise ValueError(f"no reader registered for format `{format}`")


class CSVReader:
Expand Down Expand Up @@ -116,9 +128,42 @@ def read_dataframe(self):
...


def get_writer(format, ctx, name, uri, metadata) -> "Writer":
    """Instantiate the Writer registered for *format*.

    Args:
        format: artifact format key ("csv" or "json").
        ctx: execution context for the writer.
        name: artifact name.
        uri: parsed URI. NOTE(review): CSVWriter receives the whole URI while
            JsonWriter receives only `uri.path` — looks inconsistent; kept
            as-is to preserve behavior, confirm intended.
        metadata: artifact metadata dict.

    Raises:
        ValueError: if no writer is registered for *format* — previously this
            silently fell through and returned None.
    """
    # fixed annotation: this returns a Writer, not a Reader
    if format == "csv":
        return CSVWriter(ctx, name, uri, metadata)
    if format == "json":
        return JsonWriter(ctx, name, uri.path, metadata)
    raise ValueError(f"no writer registered for format `{format}`")


class Writer(Protocol):
    """Structural interface for output-artifact writers."""


class CSVWriter:
    """CSV output writer; `write_dataframe` is currently a stub."""

    def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None:
        # Keep the full spec on the instance for the eventual implementation.
        self.ctx = ctx
        self.name = name
        self.uri = uri
        self.metadata = metadata

    def write_dataframe(self, df):
        # not implemented yet
        ...


class JsonWriter:
    """Writes a model object as JSON to a local file.

    `uri` is expected to be a plain filesystem path (the caller passes
    `uri.path`), not a full URI object.
    """

    def __init__(self, ctx, name: str, uri, metadata: dict) -> None:
        self.ctx = ctx
        self.name = name
        self.uri = uri
        self.metadata = metadata

    def write_model(self, model):
        """Serialize *model* as JSON and write it to `self.uri`."""
        import json

        serialized = json.dumps(model)
        with open(self.uri, "w") as fp:
            fp.write(serialized)


class LibSVMWriter:
    """Placeholder for a LibSVM-format writer; not implemented yet."""
2 changes: 1 addition & 1 deletion python/fate/components/spec/mlmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _log_state(self, state, message=None):
json.dump(data, f)

def safe_terminate(self):
    """Report whether it is safe to terminate; this implementation
    unconditionally answers True."""
    return True


class FlowMLMD(BaseModel):
Expand Down
17 changes: 0 additions & 17 deletions python/fate/ml/lr/arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,7 @@
class LrModuleArbiter(HeteroModule):
def __init__(
self,
penalty="l2",
*,
dual=False,
tol=1e-4,
C=1.0,
fit_intercept=True,
intercept_scaling=1,
class_weight=None,
random_state=None,
solver="lbfgs",
max_iter=100,
multi_class="auto",
verbose=0,
warm_start=False,
n_jobs=None,
l1_ratio=None,
):
self.max_iter = max_iter
self.batch_size = 5
Expand All @@ -44,11 +29,9 @@ def fit(self, ctx: Context) -> None:
for batch_ctx, _ in iter_ctx.iter(batch_loader):
g_guest_enc = batch_ctx.guest.get("g_enc")
g = decryptor.decrypt(g_guest_enc)
logger.info(f"g={g}")
batch_ctx.guest.put("g", g)
for i, g_host_enc in enumerate(batch_ctx.hosts.get("g_enc")):
g = decryptor.decrypt(g_host_enc)
batch_ctx.hosts[i].put("g", g)
logger.info(f"g={g}")
loss = decryptor.decrypt(batch_ctx.guest.get("loss"))
logger.info(f"loss={loss}")
6 changes: 4 additions & 2 deletions python/fate/ml/lr/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(
self.learning_rate = learning_rate
self.alpha = alpha

self.w = None

def fit(self, ctx: Context, train_data, validate_data=None) -> None:
"""
l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2)
Expand All @@ -41,7 +43,6 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
logger.info(f"start iter {i}")
j = 0
for batch_ctx, (X, Y) in iter_ctx.iter(batch_loader):
print(X, Y)
h = X.shape[0]

# d
Expand All @@ -65,9 +66,10 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
w -= (self.learning_rate / h) * g
logger.info(f"w={w}")
j += 1
self.w = w

def to_model(self):
...
return {"w": self.w.to_local()._storage.data.tolist()}

@classmethod
def from_model(cls, model) -> "LrModuleGuest":
Expand Down
29 changes: 8 additions & 21 deletions python/fate/ml/lr/host.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
import logging

import numpy as np
import torch
from fate.arch import tensor
from fate.arch.dataframe import CSVReader, DataLoader
from fate.interface import Context, ModelsLoader, ModelsSaver
from pandas import pandas
from fate.arch.dataframe import DataLoader
from fate.interface import Context

from ..abc.module import HeteroModule

logger = logging.getLogger(__name__)


class DataframeMock:
def __init__(self, ctx) -> None:
guest_data_path = "/Users/sage/proj/FATE/2.0.0-alpha/" "examples/data/breast_hetero_host.csv"
self.data = CSVReader(id_name="id", delimiter=",", dtype="float32").to_frame(ctx, guest_data_path)
self.num_features = 20
self.num_sample = len(self.data)

def batches(self, batch_size):
num_batchs = (self.num_sample - 1) // batch_size + 1
for chunk in np.array_split(self.data, num_batchs):
yield tensor.tensor(torch.Tensor(chunk[:, 1:]))


class LrModuleHost(HeteroModule):
def __init__(
self,
Expand All @@ -38,10 +23,10 @@ def __init__(
self.alpha = alpha
self.batch_size = batch_size

self.w = None

def fit(self, ctx: Context, train_data, validate_data=None) -> None:
# mock data
train_data = DataframeMock(ctx)
batch_loader = DataLoader(train_data.data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")
batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")
# get encryptor
encryptor = ctx.arbiter("encryptor").get()

Expand All @@ -63,8 +48,10 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
logger.info(f"w={w}")
j += 1

self.w = w

def to_model(self):
...
return {"w": self.w.to_local()._storage.data.tolist()}

@classmethod
def from_model(cls, model) -> "LrModuleHost":
Expand Down
2 changes: 1 addition & 1 deletion schemas/tasks/lr.train.arbiter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ stage: train
inputs:
parameters:
learning_rate: 0.01
max_iter: 100
max_iter: 5
batch_size: 100

outputs:
Expand Down
6 changes: 4 additions & 2 deletions schemas/tasks/lr.train.guest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ stage: train
inputs:
parameters:
learning_rate: 0.01
max_iter: 100
max_iter: 5
batch_size: 100
artifacts:
train_data:
Expand All @@ -31,7 +31,9 @@ outputs:
artifacts:
output_model:
name: output_model
uri: file:///tmp/trained_model
uri: file:///Users/sage/proj/FATE/2.0.0-alpha/models/guest.json
metadata:
format: json
train_output_data:
name: train_output_data
uri: file:///tmp/train_data_forward
Expand Down
14 changes: 11 additions & 3 deletions schemas/tasks/lr.train.host.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ stage: train
inputs:
parameters:
learning_rate: 0.01
max_iter: 100
max_iter: 5
batch_size: 100
artifacts:
train_data:
Expand All @@ -15,21 +15,29 @@ inputs:
format: csv
num_features: 20
num_samples: 569
id_name: id
delimiter: ","
dtype: float32
validate_data:
name: train_data
uri: file:///Users/sage/proj/FATE/2.0.0-alpha/examples/data/breast_hetero_host.csv
metadata:
format: csv
num_features: 20
num_samples: 569
id_name: id
delimiter: ","
dtype: float32
outputs:
artifacts:
output_model:
name: output_model
uri: file:///tmp/trained_model
uri: file:///Users/sage/proj/FATE/2.0.0-alpha/models/host.json
metadata:
format: json
train_output_data:
name: train_output_data
uri: file:///tmp/train_data_forward
uri: file:///Users/sage/temp/host_output_data
metadata:
format: csv
metrics:
Expand Down

0 comments on commit 9a8b6cd

Please sign in to comment.