Skip to content

Commit

Permalink
chore(component): refact component structure
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Nov 25, 2022
1 parent 8f64b04 commit 909c563
Show file tree
Hide file tree
Showing 24 changed files with 264 additions and 172 deletions.
1 change: 1 addition & 0 deletions python/fate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._info import __provider__, __version__
2 changes: 2 additions & 0 deletions python/fate/_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__version__ = "2.0.0.alpha"
__provider__ = "fate"
87 changes: 87 additions & 0 deletions python/fate/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Literal, Type, TypeVar

from typing_extensions import Annotated

# Role name constants; their values match the strings listed in T_ROLE below.
GUEST = "guest"
HOST = "host"
ARBITER = "arbiter"

# Literal type aliases enumerating the accepted role, stage and label strings.
T_ROLE = Literal["guest", "host", "arbiter"]
T_STAGE = Literal["train", "predict", "default"]
T_LABEL = Literal["trainable"]


class STAGES:
    """Namespace holding the stage-name strings (same values as ``T_STAGE``)."""

    # Plain string constants; the class is used only as a namespace.
    TRAIN = "train"
    PREDICT = "predict"
    DEFAULT = "default"


class LABELS:
    """Namespace holding the label strings (same values as ``T_LABEL``)."""

    # Plain string constant; the class is used only as a namespace.
    TRAINABLE = "trainable"


class OutputAnnotated:
    """Empty marker class used as the ``Annotated`` metadata of ``Output[...]``."""


class InputAnnotated:
    """Empty marker class used as the ``Annotated`` metadata of ``Input[...]``."""


# Type variable standing for the type wrapped by ``Output[...]`` / ``Input[...]``.
T = TypeVar("T")
# Generic aliases: ``Output[X]`` is ``Annotated[X, OutputAnnotated]`` and
# ``Input[X]`` is ``Annotated[X, InputAnnotated]``; the marker class records direction.
Output = Annotated[T, OutputAnnotated]
Input = Annotated[T, InputAnnotated]


class Artifact:
    """Represents a generic machine learning artifact.

    Attributes:
        type: String tag identifying the artifact kind; subclasses override it.
    """

    # The docstring must be the first statement in the class body; placed after
    # this attribute it would be a discarded expression and ``__doc__`` would be None.
    type: str = "artifact"


class Artifacts:
    """Base class for the plural artifact kinds; carries the ``"artifacts"`` type tag."""

    # Subclasses override this tag with their own plural name.
    type: str = "artifacts"


class DatasetArtifact(Artifact):
    """An artifact representing a machine learning dataset."""

    # Docstring moved to the top of the body so it becomes ``__doc__``; after the
    # attribute it was only a discarded string expression.
    type = "dataset"


class DatasetArtifacts(Artifacts):
    """Plural dataset artifact kind, tagged ``"datasets"``."""

    # NOTE(review): unlike ModelArtifacts, no ``artifact_type`` is declared here —
    # confirm whether that is intentional.
    type = "datasets"


class ModelArtifact(Artifact):
    """An artifact representing a machine learning model."""

    # Docstring moved to the top of the body so it becomes ``__doc__``; after the
    # attribute it was only a discarded string expression.
    type = "model"


class ModelArtifacts(Artifacts):
    """Plural model artifact kind, tagged ``"models"``.

    Attributes:
        type: String tag for this artifact kind.
        artifact_type: The singular artifact class associated with this kind.
    """

    type = "models"
    artifact_type: Type[Artifact] = ModelArtifact


class MetricArtifact(Artifact):
    """Artifact kind tagged ``"metric"``."""

    type = "metric"


class ClassificationMetrics(Artifact):
    """Artifact used to store classification metrics.

    Attributes:
        type: String tag for this artifact kind.
    """

    type = "classification_metrics"


class SlicedClassificationMetrics(Artifact):
    """Artifact used to store sliced classification metrics.

    Intended to work like ``ClassificationMetrics``: tasks log metrics through
    the class's log methods, except that every log call also takes the slice
    the metrics belong to.

    NOTE(review): no log methods are defined on this class in the visible
    code — confirm where they are expected to come from.
    """

    type = "sliced_classification_metrics"
2 changes: 1 addition & 1 deletion python/fate/components/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .feature_scale import feature_scale
from .hetero_lr import hetero_lr
from .intersection import intersection
from .lr import hetero_lr
from .reader import reader

BUILDIN_COMPONENTS = [
Expand Down
35 changes: 18 additions & 17 deletions python/fate/components/components/feature_scale.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
from fate.components import cpn
from fate.components.spec import (
from fate.components import (
GUEST,
HOST,
DatasetArtifact,
Input,
MetricArtifact,
ModelArtifact,
Output,
roles,
cpn,
)
from fate.ml.feature_scale import FeatureScale


@cpn.component(roles=[roles.GUEST, roles.HOST], provider="fate", version="2.0.0.alpha")
@cpn.component(roles=[GUEST, HOST])
def feature_scale(ctx, role):
...


@feature_scale.stage("train")
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@feature_scale.train()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST])
@cpn.parameter("method", type=str, default="standard", optional=False)
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST])
def feature_scale_train(
ctx,
role,
Expand All @@ -28,15 +27,13 @@ def feature_scale_train(
train_output_data,
output_model,
):
train(
ctx, train_data, train_output_data, output_model, method
)
train(ctx, train_data, train_output_data, output_model, method)


@feature_scale.stage("predict")
@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[roles.GUEST, roles.HOST])
@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@feature_scale.predict()
@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST])
@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST])
@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
def feature_scale_predict(
ctx,
role,
Expand All @@ -48,6 +45,8 @@ def feature_scale_predict(


def train(ctx, train_data, train_output_data, output_model, method):
from fate.ml.feature_scale import FeatureScale

scaler = FeatureScale(method)
with ctx.sub_ctx("train") as sub_ctx:
train_data = sub_ctx.reader(train_data).read_dataframe().data.to_local()
Expand All @@ -62,6 +61,8 @@ def train(ctx, train_data, train_output_data, output_model, method):


def predict(ctx, input_model, test_data, test_output_data):
from fate.ml.feature_scale import FeatureScale

with ctx.sub_ctx("predict") as sub_ctx:
model = sub_ctx.reader(input_model).read_model()
scaler = FeatureScale.from_model(model)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
from fate.components import cpn
from fate.components.spec import (
from fate.components import (
ARBITER,
GUEST,
HOST,
DatasetArtifact,
Input,
MetricArtifact,
ModelArtifact,
Output,
roles,
cpn,
)


@cpn.component(roles=[roles.GUEST, roles.HOST, roles.ARBITER], provider="fate", version="2.0.0.alpha")
@cpn.component(roles=[GUEST, HOST, ARBITER])
def hetero_lr(ctx, role):
...


@hetero_lr.stage()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[roles.GUEST, roles.HOST])
@hetero_lr.train()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST])
@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST])
@cpn.parameter("learning_rate", type=float, default=0.1)
@cpn.parameter("max_iter", type=int, default=100)
@cpn.parameter("batch_size", type=int, default=100)
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("train_output_metric", type=Output[MetricArtifact], roles=[roles.ARBITER])
@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
@cpn.artifact("train_output_metric", type=Output[MetricArtifact], roles=[ARBITER])
@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST])
def train(
ctx,
role,
Expand All @@ -47,10 +49,10 @@ def train(
train_arbiter(ctx, max_iter, train_output_metric)


@hetero_lr.stage()
@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[roles.GUEST, roles.HOST])
@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@hetero_lr.predict()
@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST])
@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST])
@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
def predict(
ctx,
role,
Expand Down
16 changes: 9 additions & 7 deletions python/fate/components/components/intersection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from fate.components import cpn
from fate.components.spec import DatasetArtifact, Input, Output, roles
from fate.ml.intersection import RawIntersectionGuest, RawIntersectionHost
from fate.components import GUEST, HOST, DatasetArtifact, Input, Output, cpn


@cpn.component(roles=[roles.GUEST, roles.HOST], provider="fate", version="2.0.0.alpha")
@cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.component(roles=[GUEST, HOST], provider="fate")
@cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[GUEST, HOST])
@cpn.parameter("method", type=str, default="raw", optional=True)
@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
def intersection(
ctx,
role,
Expand All @@ -23,13 +21,17 @@ def intersection(


def raw_intersect_guest(ctx, input_data, output_data):
from fate.ml.intersection import RawIntersectionGuest

data = ctx.reader(input_data).read_dataframe().data
guest_intersect_obj = RawIntersectionGuest()
intersect_data = guest_intersect_obj.fit(ctx, data)
ctx.writer(output_data).write_dataframe(intersect_data)


def raw_intersect_host(ctx, input_data, output_data):
def raw_intersect_host(ctx, input_data, output_data):
from fate.ml.intersection import RawIntersectionHost

data = ctx.reader(input_data).read_dataframe().data
host_intersect_obj = RawIntersectionHost()
intersect_data = host_intersect_obj.fit(ctx, data)
Expand Down
7 changes: 3 additions & 4 deletions python/fate/components/components/reader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from fate.components import cpn
from fate.components.spec import DatasetArtifact, Output, roles
from fate.components import GUEST, HOST, DatasetArtifact, Output, cpn


@cpn.component(roles=[roles.GUEST, roles.HOST], provider="fate", version="2.0.0.alpha")
@cpn.component(roles=[GUEST, HOST])
@cpn.parameter("path", type=str, default=None, optional=False)
@cpn.parameter("format", type=str, default="csv", optional=False)
@cpn.parameter("id_name", type=str, default="id", optional=True)
@cpn.parameter("delimiter", type=str, default=",", optional=True)
@cpn.parameter("label_name", type=str, default=None, optional=True)
@cpn.parameter("label_type", type=str, default="float32", optional=True)
@cpn.parameter("dtype", type=str, default="float32", optional=True)
@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST])
@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
def reader(
ctx,
role,
Expand Down
Loading

0 comments on commit 909c563

Please sign in to comment.