From 909c5637f42e868fc1959425d86b53f7c7061910 Mon Sep 17 00:00:00 2001 From: weiwee Date: Fri, 25 Nov 2022 02:21:09 -0800 Subject: [PATCH] chore(component): refact component structure Signed-off-by: weiwee --- python/fate/__init__.py | 1 + python/fate/_info.py | 2 + python/fate/components/__init__.py | 87 +++++++++++++++++++ python/fate/components/components/__init__.py | 2 +- .../components/components/feature_scale.py | 35 ++++---- .../components/{lr.py => hetero_lr.py} | 30 ++++--- .../components/components/intersection.py | 16 ++-- python/fate/components/components/reader.py | 7 +- python/fate/components/cpn.py | 66 ++++++++++---- .../fate/components/entrypoint/component.py | 7 +- python/fate/components/loader/artifact.py | 25 ++++++ python/fate/components/loader/component.py | 4 +- python/fate/components/loader/device.py | 10 +++ python/fate/components/spec/__init__.py | 3 - python/fate/components/spec/artifacts.py | 36 -------- python/fate/components/spec/component.py | 16 ++-- python/fate/components/spec/device.py | 13 +++ python/fate/components/spec/task.py | 25 ++---- python/fate/components/spec/types.py | 36 -------- schemas/tasks/lr.predict.guest.yaml | 3 +- schemas/tasks/lr.predict.host.yaml | 3 +- schemas/tasks/lr.train.arbiter.yaml | 3 +- schemas/tasks/lr.train.guest.yaml | 3 +- schemas/tasks/lr.train.host.yaml | 3 +- 24 files changed, 264 insertions(+), 172 deletions(-) create mode 100644 python/fate/_info.py rename python/fate/components/components/{lr.py => hetero_lr.py} (89%) create mode 100644 python/fate/components/loader/artifact.py create mode 100644 python/fate/components/loader/device.py create mode 100644 python/fate/components/spec/device.py delete mode 100644 python/fate/components/spec/types.py diff --git a/python/fate/__init__.py b/python/fate/__init__.py index e69de29bb2..f61e2ef24d 100644 --- a/python/fate/__init__.py +++ b/python/fate/__init__.py @@ -0,0 +1 @@ +from ._info import __provider__, __version__ diff --git a/python/fate/_info.py b/python/fate/_info.py new file mode 100644 index 0000000000..4e8d3dc297 --- /dev/null +++ b/python/fate/_info.py @@ -0,0 +1,2 @@ +__version__ = "2.0.0.alpha" +__provider__ = "fate" diff --git a/python/fate/components/__init__.py b/python/fate/components/__init__.py index e69de29bb2..1cead0b372 100644 --- a/python/fate/components/__init__.py +++ b/python/fate/components/__init__.py @@ -0,0 +1,87 @@ +from typing import Literal, Type, TypeVar + +from typing_extensions import Annotated + +GUEST = "guest" +HOST = "host" +ARBITER = "arbiter" + +T_ROLE = Literal["guest", "host", "arbiter"] +T_STAGE = Literal["train", "predict", "default"] +T_LABEL = Literal["trainable"] + + +class STAGES: + TRAIN = "train" + PREDICT = "predict" + DEFAULT = "default" + + +class LABELS: + TRAINABLE = "trainable" + + +class OutputAnnotated: + ... + + +class InputAnnotated: + ... + + +T = TypeVar("T") +Output = Annotated[T, OutputAnnotated] +Input = Annotated[T, InputAnnotated] + + +class Artifact: + type: str = "artifact" + """Represents a generic machine learning artifact. + """ + + +class Artifacts: + type: str = "artifacts" + + +class DatasetArtifact(Artifact): + type = "dataset" + """An artifact representing a machine learning dataset. + """ + + +class DatasetArtifacts(Artifacts): + type = "datasets" + + +class ModelArtifact(Artifact): + type = "model" + """An artifact representing a machine learning model. 
+ """ + + +class ModelArtifacts(Artifacts): + type = "models" + artifact_type: Type[Artifact] = ModelArtifact + + +class MetricArtifact(Artifact): + type = "metric" + + +class ClassificationMetrics(Artifact): + """An artifact for storing classification metrics.""" + + type = "classification_metrics" + + +class SlicedClassificationMetrics(Artifact): + """An artifact for storing sliced classification metrics. + + Similar to ``ClassificationMetrics``, tasks using this class are + expected to use log methods of the class to log metrics with the + difference being each log method takes a slice to associate the + ``ClassificationMetrics``. + """ + + type = "sliced_classification_metrics" diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index b2eecc6619..fcf099a6c1 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -1,6 +1,6 @@ from .feature_scale import feature_scale +from .hetero_lr import hetero_lr from .intersection import intersection -from .lr import hetero_lr from .reader import reader BUILDIN_COMPONENTS = [ diff --git a/python/fate/components/components/feature_scale.py b/python/fate/components/components/feature_scale.py index cb0b354abc..8ef2605d9a 100644 --- a/python/fate/components/components/feature_scale.py +++ b/python/fate/components/components/feature_scale.py @@ -1,25 +1,24 @@ -from fate.components import cpn -from fate.components.spec import ( +from fate.components import ( + GUEST, + HOST, DatasetArtifact, Input, - MetricArtifact, ModelArtifact, Output, - roles, + cpn, ) -from fate.ml.feature_scale import FeatureScale -@cpn.component(roles=[roles.GUEST, roles.HOST], provider="fate", version="2.0.0.alpha") +@cpn.component(roles=[GUEST, HOST]) def feature_scale(ctx, role): ... 
-@feature_scale.stage("train") -@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) +@feature_scale.train() +@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) @cpn.parameter("method", type=str, default="standard", optional=False) -@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[roles.GUEST, roles.HOST]) +@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) +@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST]) def feature_scale_train( ctx, role, @@ -28,15 +27,13 @@ def feature_scale_train( train_output_data, output_model, ): - train( - ctx, train_data, train_output_data, output_model, method - ) + train(ctx, train_data, train_output_data, output_model, method) -@feature_scale.stage("predict") -@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) +@feature_scale.predict() +@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST]) +@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST]) +@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) def feature_scale_predict( ctx, role, @@ -48,6 +45,8 @@ def feature_scale_predict( def train(ctx, train_data, train_output_data, output_model, method): + from fate.ml.feature_scale import FeatureScale + scaler = FeatureScale(method) with ctx.sub_ctx("train") as sub_ctx: train_data = sub_ctx.reader(train_data).read_dataframe().data.to_local() @@ -62,6 +61,8 @@ def train(ctx, train_data, train_output_data, output_model, method): def predict(ctx, input_model, test_data, test_output_data): + from fate.ml.feature_scale import FeatureScale + with ctx.sub_ctx("predict") as sub_ctx: model = sub_ctx.reader(input_model).read_model() scaler = FeatureScale.from_model(model) diff --git a/python/fate/components/components/lr.py b/python/fate/components/components/hetero_lr.py similarity index 89% rename from python/fate/components/components/lr.py rename to python/fate/components/components/hetero_lr.py index 6dd2f90cd9..9ce258705d 100644 --- a/python/fate/components/components/lr.py +++ b/python/fate/components/components/hetero_lr.py @@ -1,28 +1,30 @@ -from fate.components import cpn -from fate.components.spec import ( +from fate.components import ( + ARBITER, + GUEST, + HOST, DatasetArtifact, Input, MetricArtifact, ModelArtifact, Output, - roles, + cpn, ) -@cpn.component(roles=[roles.GUEST, roles.HOST, roles.ARBITER], provider="fate", version="2.0.0.alpha") +@cpn.component(roles=[GUEST, HOST, ARBITER]) def hetero_lr(ctx, role): ... 
-@hetero_lr.stage() -@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[roles.GUEST, roles.HOST]) +@hetero_lr.train() +@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) +@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST]) @cpn.parameter("learning_rate", type=float, default=0.1) @cpn.parameter("max_iter", type=int, default=100) @cpn.parameter("batch_size", type=int, default=100) -@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("train_output_metric", type=Output[MetricArtifact], roles=[roles.ARBITER]) -@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[roles.GUEST, roles.HOST]) +@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) +@cpn.artifact("train_output_metric", type=Output[MetricArtifact], roles=[ARBITER]) +@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST]) def train( ctx, role, @@ -47,10 +49,10 @@ def train( train_arbiter(ctx, max_iter, train_output_metric) -@hetero_lr.stage() -@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[roles.GUEST, roles.HOST]) -@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) +@hetero_lr.predict() +@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST]) +@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST]) +@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) def predict( ctx, role, diff --git a/python/fate/components/components/intersection.py b/python/fate/components/components/intersection.py index 5f1cc079b0..5afe9b7adf 100644 --- a/python/fate/components/components/intersection.py +++ b/python/fate/components/components/intersection.py @@ -1,12 +1,10 @@ -from fate.components import cpn -from fate.components.spec import DatasetArtifact, Input, Output, roles -from fate.ml.intersection import RawIntersectionGuest, RawIntersectionHost +from fate.components import GUEST, HOST, DatasetArtifact, Input, Output, cpn -@cpn.component(roles=[roles.GUEST, roles.HOST], provider="fate", version="2.0.0.alpha") -@cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) +@cpn.component(roles=[GUEST, HOST], provider="fate") +@cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) @cpn.parameter("method", type=str, default="raw", optional=True) -@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) +@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) def intersection( ctx, role, @@ -23,13 +21,17 @@ def intersection( def raw_intersect_guest(ctx, input_data, output_data): + from fate.ml.intersection import RawIntersectionGuest + data = ctx.reader(input_data).read_dataframe().data guest_intersect_obj = RawIntersectionGuest() intersect_data = guest_intersect_obj.fit(ctx, data) ctx.writer(output_data).write_dataframe(intersect_data) -def raw_intersect_host(ctx, input_data, output_data): +def raw_intersect_host(ctx, input_data, output_data): + from fate.ml.intersection import RawIntersectionHost + data = ctx.reader(input_data).read_dataframe().data host_intersect_obj = 
RawIntersectionHost() intersect_data = host_intersect_obj.fit(ctx, data) diff --git a/python/fate/components/components/reader.py b/python/fate/components/components/reader.py index 1f0fc7c61a..8317e9bf70 100644 --- a/python/fate/components/components/reader.py +++ b/python/fate/components/components/reader.py @@ -1,8 +1,7 @@ -from fate.components import cpn -from fate.components.spec import DatasetArtifact, Output, roles +from fate.components import GUEST, HOST, DatasetArtifact, Output, cpn -@cpn.component(roles=[roles.GUEST, roles.HOST], provider="fate", version="2.0.0.alpha") +@cpn.component(roles=[GUEST, HOST]) @cpn.parameter("path", type=str, default=None, optional=False) @cpn.parameter("format", type=str, default="csv", optional=False) @cpn.parameter("id_name", type=str, default="id", optional=True) @@ -10,7 +9,7 @@ @cpn.parameter("label_name", type=str, default=None, optional=True) @cpn.parameter("label_type", type=str, default="float32", optional=True) @cpn.parameter("dtype", type=str, default="float32", optional=True) -@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[roles.GUEST, roles.HOST]) +@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) def reader( ctx, role, diff --git a/python/fate/components/cpn.py b/python/fate/components/cpn.py index a3d7d846de..72e4d2f1bf 100644 --- a/python/fate/components/cpn.py +++ b/python/fate/components/cpn.py @@ -124,9 +124,9 @@ def __init__( self.stages: Dict[str, _Component] = {} - def validate_and_extract_execute_args(self, config): - role = config.role - stage = config.stage + def validate_and_extract_execute_args(self, role, stage, inputs_artifacts, outputs_artifacts, inputs_parameters): + from fate.components.loader.artifact import load_artifact + name_artifact_mapping = {artifact.name: artifact for artifact in self.artifacts} name_parameter_mapping = {parameter.name: parameter for parameter in self.parameters} execute_args = [role] @@ -136,9 +136,9 @@ def validate_and_extract_execute_args(self, config): if (arti.stages is None or stage in arti.stages) and (arti.roles is None or role in arti.roles): # get corresponding applying config if isinstance(arti, _InputArtifactDeclareClass): - artifact_apply = config.inputs.artifacts.get(arg) + artifact_apply = inputs_artifacts.get(arg) elif isinstance(arti, _OutputArtifactDeclareClass): - artifact_apply = config.outputs.artifacts.get(arg) + artifact_apply = outputs_artifacts.get(arg) else: artifact_apply = None @@ -153,7 +153,7 @@ def validate_and_extract_execute_args(self, config): try: # annotated metadata drop in inherite, so pass type as argument here # maybe we could find more elegant way some day - execute_args.append(arti.type.parse_desc(artifact_apply)) + execute_args.append(load_artifact(artifact_apply, arti.type)) except Exception as e: raise ComponentApplyError( f"artifact `{arg}` with applying config `{artifact_apply}` can't apply to `{arti}`" @@ -163,12 +163,12 @@ def validate_and_extract_execute_args(self, config): # arg support to be parameter elif parameter := name_parameter_mapping.get(arg): - parameter_apply = config.inputs.parameters.get(arg) + parameter_apply = inputs_parameters.get(arg) if parameter_apply is None: if not parameter.optional: raise ComponentApplyError(f"parameter `{arg}` required, declare: `{parameter}`") else: - execute_args.append(parameter_apply) + execute_args.append(parameter.default) else: if type(parameter_apply) != parameter.type: raise ComponentApplyError( @@ -189,7 +189,7 @@ def execute(self, ctx, *args): 
def get_artifacts(self): mapping = {artifact.name: artifact for artifact in self.artifacts} - for stage_name, stage_cpn in self.stages.items(): + for _, stage_cpn in self.stages.items(): for artifact_name, artifact in stage_cpn.get_artifacts().items(): # update or merge if artifact_name not in mapping: @@ -215,7 +215,7 @@ def get_artifacts(self): def get_parameters(self): mapping = {parameter.name: parameter for parameter in self.parameters} - for stage_name, stage_cpn in self.stages.items(): + for _, stage_cpn in self.stages.items(): for parameter_name, parameter in stage_cpn.get_parameters().items(): # update or error if parameter_name not in mapping: @@ -237,6 +237,7 @@ def get_parameters(self): return mapping def dict(self): + from fate.components import InputAnnotated, OutputAnnotated from fate.components.spec.component import ( ArtifactSpec, ComponentSpec, @@ -245,13 +246,12 @@ def dict(self): OutputDefinitionsSpec, ParameterSpec, ) - from fate.components.spec.types import InputAnnotated, OutputAnnotated input_artifacts = {} output_artifacts = {} for artifact_name, artifact in self.get_artifacts().items(): annotated = getattr(artifact.type, "__metadata__", [None])[0] - roles = getattr(artifact, "roles") or self.roles + roles = artifact.roles or self.roles if annotated == OutputAnnotated: output_artifacts[artifact_name] = ArtifactSpec( type=artifact.type.type, optional=artifact.optional, roles=roles, stages=artifact.stages @@ -266,7 +266,10 @@ def dict(self): input_parameters = {} for parameter_name, parameter in self.get_parameters().items(): input_parameters[parameter_name] = ParameterSpec( - type=parameter.type.__name__, default=parameter.default, optional=parameter.optional + type=parameter.type.__name__, + default=parameter.default, + optional=parameter.optional, + description=parameter.desc, ) input_definition = InputDefinitionsSpec(parameters=input_parameters, artifacts=input_artifacts) @@ -299,8 +302,20 @@ def dump_yaml(self, stream=None): if inefficient: return stream.getvalue() + def predict(self, roles=[], provider: Optional[str] = None, version: Optional[str] = None, description=None): + from fate.components import STAGES + + return self.stage( + roles=roles, name=STAGES.PREDICT, provider=provider, version=version, description=description + ) + + def train(self, roles=[], provider: Optional[str] = None, version: Optional[str] = None, description=None): + from fate.components import STAGES + + return self.stage(roles=roles, name=STAGES.TRAIN, provider=provider, version=version, description=description) + def stage( - self, name=None, roles=[], provider: Optional[str] = None, version: Optional[str] = None, description=None + self, roles=[], name=None, provider: Optional[str] = None, version: Optional[str] = None, description=None ): r"""Creates a new stage component with :class:`_Component` and uses the decorated function as callback. This will also automatically attach all decorated @@ -327,7 +342,13 @@ def wrap(f): return wrap -def component(name=None, roles=[], provider="fate", version="2.0.0.alpha", description=None): +def component( + roles: list, + name: Optional[str] = None, + provider: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, +): r"""Creates a new :class:`_Component` and uses the decorated function as callback. This will also automatically attach all decorated :func:`artifact`\s and :func:`parameter`\s as parameters to the component execution. 
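For context (illustrative, not part of the diff): provider and version are now optional on cpn.component() and fall back to fate.__provider__ / fate.__version__, and the new train()/predict() helpers are thin wrappers over stage().

from fate.components import GUEST, HOST, cpn


# Equivalent to @cpn.component(roles=[GUEST, HOST], provider="fate", version="2.0.0.alpha"):
# the omitted provider/version are filled in from fate._info.
@cpn.component(roles=[GUEST, HOST])
def example_cpn(ctx, role):
    ...


# Likewise, example_cpn.train() and example_cpn.predict() are shorthand for
# example_cpn.stage(name="train", ...) and example_cpn.stage(name="predict", ...).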
@@ -342,12 +363,20 @@ def component(name=None, roles=[], provider="fate", version="2.0.0.alpha", descr :param name: the name of the component. This defaults to the function name. """ + from fate import __provider__, __version__ + + if version is None: + version = __version__ + if provider is None: + provider = __provider__ return _component( name=name, roles=roles, provider=provider, version=version, description=description, is_subcomponent=False ) def _component(name, roles, provider, version, description, is_subcomponent): + from fate.components import STAGES + def decorator(f): cpn_name = name or f.__name__.lower() if isinstance(f, _Component): @@ -368,7 +397,7 @@ def decorator(f): if is_subcomponent: artifact.stages = [cpn_name] else: - artifact.stages = ["default"] + artifact.stages = [STAGES.DEFAULT] desc = description if desc is None: desc = inspect.getdoc(f) @@ -414,7 +443,7 @@ def __str__(self) -> str: def _create_artifact_declare_class(name, type, roles, desc, optional): - from fate.components.spec.types import InputAnnotated, OutputAnnotated + from fate.components import InputAnnotated, OutputAnnotated annotates = getattr(type, "__metadata__", [None]) if OutputAnnotated in annotates: @@ -446,12 +475,13 @@ def __init__(self, name, type, default, optional, desc) -> None: self.type = type self.default = default self.optional = optional + self.desc = desc def __str__(self) -> str: return f"Parameter" -def parameter(name, type, default=None, optional=True, desc=None): +def parameter(name, type, default=None, optional=True, desc=""): """attaches an parameter to the component.""" def decorator(f): diff --git a/python/fate/components/entrypoint/component.py b/python/fate/components/entrypoint/component.py index b1a9f460e2..db1651c35c 100644 --- a/python/fate/components/entrypoint/component.py +++ b/python/fate/components/entrypoint/component.py @@ -5,6 +5,7 @@ from fate.arch.context import Context from fate.components.loader.component import load_component from fate.components.loader.computing import load_computing +from fate.components.loader.device import load_device from fate.components.loader.federation import load_federation from fate.components.loader.mlmd import load_mlmd from fate.components.spec.task import TaskConfigSpec @@ -17,7 +18,7 @@ def execute_component(config: TaskConfigSpec): mlmd = load_mlmd(config.conf.mlmd, context_name) computing = load_computing(config.conf.computing) federation = load_federation(config.conf.federation, computing) - device = config.conf.get_device() + device = load_device(config.conf.device) ctx = Context( context_name=context_name, device=device, @@ -36,7 +37,9 @@ def execute_component(config: TaskConfigSpec): raise ValueError(f"stage `{stage}` for component `{component.name}` not supported") else: component = component.stages[config.stage] - args = component.validate_and_extract_execute_args(config) + args = component.validate_and_extract_execute_args( + config.role, config.stage, config.inputs.artifacts, config.outputs.artifacts, config.inputs.parameters + ) component.execute(ctx, *args) except Exception as e: tb = traceback.format_exc() diff --git a/python/fate/components/loader/artifact.py b/python/fate/components/loader/artifact.py new file mode 100644 index 0000000000..4b8e6dee1e --- /dev/null +++ b/python/fate/components/loader/artifact.py @@ -0,0 +1,25 @@ +def load_artifact(data, artifact_type): + from fate.components.spec.artifacts import ( + Artifact, + Artifacts, + DatasetArtifact, + DatasetArtifacts, + MetricArtifact, + 
ModelArtifact, + ModelArtifacts, + ) + + if isinstance(data, list): + if artifact_type == DatasetArtifacts: + return DatasetArtifacts([DatasetArtifact(name=d.name, uri=d.uri, metadata=d.metadata) for d in data]) + if artifact_type == ModelArtifacts: + return ModelArtifacts([ModelArtifact(name=d.name, uri=d.uri, metadata=d.metadata) for d in data]) + return Artifacts([Artifact(name=d.name, uri=d.uri, metadata=d.metadata) for d in data]) + else: + if artifact_type == DatasetArtifact: + return DatasetArtifact(name=data.name, uri=data.uri, metadata=data.metadata) + if artifact_type == ModelArtifact: + return ModelArtifact(name=data.name, uri=data.uri, metadata=data.metadata) + if artifact_type == MetricArtifact: + return MetricArtifact(name=data.name, uri=data.uri, metadata=data.metadata) + return Artifact(name=data.name, uri=data.uri, metadata=data.metadata) diff --git a/python/fate/components/loader/component.py b/python/fate/components/loader/component.py index 7cc343ff04..042cdc692e 100644 --- a/python/fate/components/loader/component.py +++ b/python/fate/components/loader/component.py @@ -33,14 +33,14 @@ def list_components(): import pkg_resources from fate.components.components import BUILDIN_COMPONENTS - buildin_components = list(BUILDIN_COMPONENTS.keys()) + buildin_components = [c.name for c in BUILDIN_COMPONENTS] third_parties_components = [] for cpn_ep in pkg_resources.iter_entry_points(group="fate.ext.component"): try: candidate_cpn = cpn_ep.load() candidate_cpn_name = candidate_cpn.name - third_parties_components.append(candidate_cpn_name) + third_parties_components.append([candidate_cpn_name]) except Exception as e: logger.warning( f"register cpn from entrypoint(named={cpn_ep.name}, module={cpn_ep.module_name}) failed: {e}" diff --git a/python/fate/components/loader/device.py b/python/fate/components/loader/device.py new file mode 100644 index 0000000000..ca7453b95b --- /dev/null +++ b/python/fate/components/loader/device.py @@ -0,0 +1,10 @@ +def load_device(device_spec): + from fate.arch.unify import device + from fate.components.spec.device import CPUSpec, GPUSpec + + if isinstance(device_spec, CPUSpec): + return device.CPU + + if isinstance(device_spec, GPUSpec): + return device.CUDA + raise ValueError(f"device `{device_spec}` not implemeted yet") diff --git a/python/fate/components/spec/__init__.py b/python/fate/components/spec/__init__.py index b240c35501..e69de29bb2 100644 --- a/python/fate/components/spec/__init__.py +++ b/python/fate/components/spec/__init__.py @@ -1,3 +0,0 @@ -from .artifacts import * -from .component import * -from .types import * diff --git a/python/fate/components/spec/artifacts.py b/python/fate/components/spec/artifacts.py index 9493728f4f..81e9197eba 100644 --- a/python/fate/components/spec/artifacts.py +++ b/python/fate/components/spec/artifacts.py @@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Type -from .types import Input, Output - class Artifact: type: str = "artifact" @@ -57,32 +55,6 @@ def __repr__(self) -> str: def parse_desc(cls, desc): return cls(uri=desc.uri, name=desc.name, metadata=desc.metadata) - @property - def path(self) -> str: - return self._get_path() - - @path.setter - def path(self, path: str) -> None: - self._set_path(path) - - def _get_path(self) -> Optional[str]: - if self.uri.startswith("gs://"): - return _GCS_LOCAL_MOUNT_PREFIX + self.uri[len("gs://") :] - elif self.uri.startswith("minio://"): - return _MINIO_LOCAL_MOUNT_PREFIX + self.uri[len("minio://") :] - elif self.uri.startswith("s3://"): - return 
_S3_LOCAL_MOUNT_PREFIX + self.uri[len("s3://") :] - return None - - def _set_path(self, path: str) -> None: - if path.startswith(_GCS_LOCAL_MOUNT_PREFIX): - path = "gs://" + path[len(_GCS_LOCAL_MOUNT_PREFIX) :] - elif path.startswith(_MINIO_LOCAL_MOUNT_PREFIX): - path = "minio://" + path[len(_MINIO_LOCAL_MOUNT_PREFIX) :] - elif path.startswith(_S3_LOCAL_MOUNT_PREFIX): - path = "s3://" + path[len(_S3_LOCAL_MOUNT_PREFIX) :] - self.uri = path - class Artifacts: type: str @@ -455,11 +427,3 @@ def load_confusion_matrix(self, slice: str, categories: List[str], matrix: List[ self._upsert_classification_metrics_for_slice(slice) self._sliced_metrics[slice].log_confusion_matrix_cell(categories, matrix) self._update_metadata(slice) - - -TrainData = Input[DatasetArtifact] -ValidateData = Input[DatasetArtifact] -TestData = Input[DatasetArtifact] -TrainOutputData = Output[DatasetArtifact] -TestOutputData = Output[DatasetArtifact] -Metrics = Output[MetricArtifact] diff --git a/python/fate/components/spec/component.py b/python/fate/components/spec/component.py index ab7000055b..e03ad743ce 100644 --- a/python/fate/components/spec/component.py +++ b/python/fate/components/spec/component.py @@ -1,23 +1,21 @@ -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Optional +from fate.components import T_LABEL, T_ROLE, T_STAGE from pydantic import BaseModel -roles = Literal["guest", "host", "arbiter"] -stages = Literal["train", "predict", "default"] -labels = Literal["trainable"] - class ParameterSpec(BaseModel): type: str default: Any optional: bool + description: str = "" class ArtifactSpec(BaseModel): type: str optional: bool - stages: Optional[List[stages]] - roles: List[roles] + stages: Optional[List[T_STAGE]] + roles: List[T_ROLE] class InputDefinitionsSpec(BaseModel): @@ -34,8 +32,8 @@ class ComponentSpec(BaseModel): description: str provider: str version: str - labels: List[labels] - roles: List[roles] + labels: List[T_LABEL] + roles: List[T_ROLE] input_definitions: InputDefinitionsSpec output_definitions: OutputDefinitionsSpec diff --git a/python/fate/components/spec/device.py b/python/fate/components/spec/device.py new file mode 100644 index 0000000000..3c0fd86411 --- /dev/null +++ b/python/fate/components/spec/device.py @@ -0,0 +1,13 @@ +from typing import Literal + +import pydantic + + +class CPUSpec(pydantic.BaseModel): + type: Literal["CPU"] + metadata: dict = {} + + +class GPUSpec(pydantic.BaseModel): + type: Literal["GPU"] + metadata: dict = {} diff --git a/python/fate/components/spec/task.py b/python/fate/components/spec/task.py index 241d753ce5..5e16a6875c 100644 --- a/python/fate/components/spec/task.py +++ b/python/fate/components/spec/task.py @@ -1,36 +1,25 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Optional, Union import pydantic -from fate.components.spec.computing import ( - EggrollComputingSpec, - SparkComputingSpec, - StandaloneComputingSpec, -) -from fate.components.spec.federation import ( + +from .computing import EggrollComputingSpec, SparkComputingSpec, StandaloneComputingSpec +from .device import CPUSpec, GPUSpec +from .federation import ( EggrollFederationSpec, RabbitMQFederationSpec, StandaloneFederationSpec, ) -from fate.components.spec.mlmd import CustomMLMDSpec, FlowMLMDSpec, PipelineMLMDSpec - from .logger import CustomLogger, FlowLogger, PipelineLogger +from .mlmd import CustomMLMDSpec, FlowMLMDSpec, PipelineMLMDSpec class TaskConfSpec(pydantic.BaseModel): - device: 
Literal["CPU", "GPU"] + device: Union[CPUSpec, GPUSpec] computing: Union[StandaloneComputingSpec, EggrollComputingSpec, SparkComputingSpec] federation: Union[StandaloneFederationSpec, EggrollFederationSpec, RabbitMQFederationSpec] logger: Union[PipelineLogger, FlowLogger, CustomLogger] mlmd: Union[PipelineMLMDSpec, FlowMLMDSpec, CustomMLMDSpec] - def get_device(self): - from fate.arch.unify import device - - for dev in device: - if dev.name == self.device.strip().upper(): - return dev - raise ValueError(f"should be one of {[dev.name for dev in device]}") - class ArtifactSpec(pydantic.BaseModel): name: str diff --git a/python/fate/components/spec/types.py b/python/fate/components/spec/types.py deleted file mode 100644 index 75cb25078b..0000000000 --- a/python/fate/components/spec/types.py +++ /dev/null @@ -1,36 +0,0 @@ -import enum -from typing import List, TypeVar - -from typing_extensions import Annotated - - -class OutputAnnotated: - ... - - -class InputAnnotated: - ... - - -T = TypeVar("T") -Output = Annotated[T, OutputAnnotated] -Input = Annotated[T, InputAnnotated] - - -class roles(str, enum.Enum): - GUEST = "guest" - HOST = "host" - ARBITER = "arbiter" - - @classmethod - def get_all(cls) -> List["roles"]: - return [roles.GUEST, roles.HOST, roles.ARBITER] - - -class stages(str, enum.Enum): - TRAIN = "train" - PREDICT = "predict" - - -class labels(str, enum.Enum): - TRAINABLE = "trainable" diff --git a/schemas/tasks/lr.predict.guest.yaml b/schemas/tasks/lr.predict.guest.yaml index 6dca168791..76abd93da2 100644 --- a/schemas/tasks/lr.predict.guest.yaml +++ b/schemas/tasks/lr.predict.guest.yaml @@ -39,7 +39,8 @@ conf: basepath: /Users/sage/alpha/logs level: DEBUG debug_mode: true - device: CPU + device: + type: CPU computing: type: standalone metadata: diff --git a/schemas/tasks/lr.predict.host.yaml b/schemas/tasks/lr.predict.host.yaml index 02c36bf2f2..92108ae2d5 100644 --- a/schemas/tasks/lr.predict.host.yaml +++ b/schemas/tasks/lr.predict.host.yaml @@ -37,7 +37,8 @@ conf: basepath: /Users/sage/alpha/logs level: DEBUG debug_mode: true - device: CPU + device: + type: CPU computing: type: standalone metadata: diff --git a/schemas/tasks/lr.train.arbiter.yaml b/schemas/tasks/lr.train.arbiter.yaml index 40e87392e1..61b0a81ec6 100644 --- a/schemas/tasks/lr.train.arbiter.yaml +++ b/schemas/tasks/lr.train.arbiter.yaml @@ -26,7 +26,8 @@ conf: basepath: /Users/sage/alpha/logs level: DEBUG debug_mode: true - device: CPU + device: + type: CPU computing: type: standalone metadata: diff --git a/schemas/tasks/lr.train.guest.yaml b/schemas/tasks/lr.train.guest.yaml index 1b75b915c8..d6fd1c8839 100644 --- a/schemas/tasks/lr.train.guest.yaml +++ b/schemas/tasks/lr.train.guest.yaml @@ -55,7 +55,8 @@ conf: basepath: /Users/sage/alpha/logs level: DEBUG debug_mode: true - device: CPU + device: + type: CPU computing: type: standalone metadata: diff --git a/schemas/tasks/lr.train.host.yaml b/schemas/tasks/lr.train.host.yaml index db29340b5d..576bb43b7e 100644 --- a/schemas/tasks/lr.train.host.yaml +++ b/schemas/tasks/lr.train.host.yaml @@ -51,7 +51,8 @@ conf: basepath: /Users/sage/alpha/logs level: DEBUG debug_mode: true - device: CPU + device: + type: CPU computing: type: standalone metadata: