feat(component): add parameters validate
Signed-off-by: weiwee <wbwmat@gmail.com>
sagewe committed Dec 12, 2022
1 parent c541c67 commit 5358bce
Showing 5 changed files with 83 additions and 10 deletions.
9 changes: 6 additions & 3 deletions python/fate/components/components/hetero_lr.py
@@ -9,6 +9,7 @@
     Output,
     Role,
     cpn,
+    params,
 )


@@ -20,9 +21,11 @@ def hetero_lr(ctx, role):
 @hetero_lr.train()
 @cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST])
 @cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST])
-@cpn.parameter("learning_rate", type=float, default=0.1)
-@cpn.parameter("max_iter", type=int, default=100)
-@cpn.parameter("batch_size", type=int, default=100)
+@cpn.parameter("learning_rate", type=params.ConFloat(gt=0.0), default=0.1, desc="learning rate")
+@cpn.parameter("max_iter", type=params.ConInt(gt=0), default=100)
+@cpn.parameter(
+    "batch_size", type=params.ConInt(), default=100, desc="batch size, value less or equals to 0 means full batch"
+)
 @cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
 @cpn.artifact("train_output_metric", type=Output[MetricArtifact], roles=[ARBITER])
 @cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST])
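Worth noting: max_iter is declared with params.ConInt(gt=0), while batch_size uses an unbounded params.ConInt() so the sentinel values less than or equal to 0 (meaning full batch) still pass validation. A quick sketch of that difference, assuming the params module behaves as defined later in this commit:

import pydantic

from fate.components import params

assert params.ConInt().parse(-1) == -1        # accepted: batch_size <= 0 means "full batch"
assert params.ConInt(gt=0).parse(100) == 100  # accepted: within the declared bound
try:
    params.ConInt(gt=0).parse(-1)             # violates gt=0
except pydantic.ValidationError:
    print("rejected: max_iter must be greater than 0")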
12 changes: 11 additions & 1 deletion python/fate/components/cpn.py
@@ -207,9 +207,19 @@ def dict(self):
                 raise ValueError(f"bad artifact: {artifact}")
 
         input_parameters = {}
+        from fate.components.params import Parameter
+
         for parameter_name, parameter in self.get_parameters().items():
+            if isinstance(parameter.type, Parameter):  # recommended
+                type_name = type(parameter.type).__name__
+                type_meta = parameter.type.dict()
+            else:
+                type_name = parameter.type.__name__
+                type_meta = {}
+
             input_parameters[parameter_name] = ParameterSpec(
-                type=parameter.type.__name__,
+                type=type_name,
+                type_meta=type_meta,
                 default=parameter.default,
                 optional=parameter.optional,
                 description=parameter.desc,
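The effect of the branch above: parameters declared with the new constraint types now serialize both a type name and their bounds into the component spec, instead of only a bare builtin type name. A rough sketch of the resulting entry for max_iter; values other than type, type_meta, and default depend on the cpn.parameter defaults and are assumed here:

# Illustrative only: roughly what the branch above produces for max_iter.
max_iter_spec = {
    "type": "ConInt",        # type(parameter.type).__name__
    "type_meta": {"gt": 0},  # parameter.type.dict() keeps only the bounds that were set
    "default": 100,
    "optional": False,       # assumed
    "description": "",       # assumed: no desc was given for max_iter
}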
12 changes: 6 additions & 6 deletions python/fate/components/entrypoint/component.py
@@ -4,6 +4,7 @@
 from typing import Any, Dict
 
 from fate.arch.context import Context
+from fate.components import params
 from fate.components.cpn import (
     ComponentApplyError,
     _Component,
@@ -98,12 +99,11 @@ def parse_input_parameters(mlmd: MLMD, cpn: _Component, input_parameters: Dict[s
             mlmd.io.log_input_parameter(parameter.name, parameter.default)
         else:
             # TODO: enhance type validate
-            if type(parameter_apply) != parameter.type:
-                raise ComponentApplyError(
-                    f"parameter `{arg}` with applying config `{parameter_apply}` can't apply to `{parameter}`"
-                    f": {type(parameter_apply)} != {parameter.type}"
-                )
-            execute_parameters[parameter.name] = parameter_apply
+            try:
+                value = params.parse(parameter.type, parameter_apply)
+            except Exception as e:
+                raise ComponentApplyError(f"apply value `{parameter_apply}` to parameter `{arg}` failed:\n{e}")
+            execute_parameters[parameter.name] = value
             mlmd.io.log_input_parameter(parameter.name, parameter_apply)
     return execute_parameters

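The practical difference from the removed exact-type check: a job configuration that supplies, say, the string "100" for an integer parameter is now coerced by pydantic instead of being rejected outright, and only values that violate the declared constraint fail, wrapped as ComponentApplyError by the loop above. A minimal sketch using only the new params module:

from fate.components import params

# the removed check compared Python types exactly, so a string value always raised
assert params.parse(params.ConInt(gt=0), "100") == 100  # pydantic coerces the string

# constraint violations still fail and are re-raised as ComponentApplyError above
# params.parse(params.ConInt(gt=0), 0)  -> pydantic.ValidationError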
59 changes: 59 additions & 0 deletions python/fate/components/params/__init__.py
@@ -0,0 +1,59 @@
import pydantic


class Parameter:
    def dict(self):
        raise NotImplementedError()


class ConInt(Parameter):
    def __init__(self, gt: int = None, ge: int = None, lt: int = None, le: int = None) -> None:
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def parse(self, obj):
        return pydantic.parse_obj_as(pydantic.conint(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le), obj)

    def dict(self):
        meta = {}
        if self.gt is not None:
            meta["gt"] = self.gt
        if self.ge is not None:
            meta["ge"] = self.ge
        if self.lt is not None:
            meta["lt"] = self.lt
        if self.le is not None:
            meta["le"] = self.le
        return meta


class ConFloat(Parameter):
    def __init__(self, gt: float = None, ge: float = None, lt: float = None, le: float = None) -> None:
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def parse(self, obj):
        return pydantic.parse_obj_as(pydantic.confloat(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le), obj)

    def dict(self):
        meta = {}
        if self.gt is not None:
            meta["gt"] = self.gt
        if self.ge is not None:
            meta["ge"] = self.ge
        if self.lt is not None:
            meta["lt"] = self.lt
        if self.le is not None:
            meta["le"] = self.le
        return meta


def parse(parameter_type, obj):
    if isinstance(parameter_type, Parameter) and hasattr(parameter_type, "parse"):
        return getattr(parameter_type, "parse")(obj)
    else:
        return pydantic.parse_obj_as(parameter_type, obj)
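The module-level parse() falls back to pydantic.parse_obj_as for parameters that still declare plain Python types, so components that have not migrated to the constrained types keep working. A small sketch, assuming pydantic v1 semantics:

import pydantic

from fate.components import params

assert params.parse(params.ConFloat(ge=0.0), "0.3") == 0.3  # dispatched to ConFloat.parse
assert params.parse(float, "0.3") == 0.3                    # falls back to pydantic.parse_obj_as
assert params.parse(bool, "false") is False                 # pydantic's bool coercion
try:
    params.parse(params.ConFloat(ge=0.0), -1)
except pydantic.ValidationError:
    pass  # bound ge=0.0 violated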
1 change: 1 addition & 0 deletions python/fate/components/spec/component.py
@@ -9,6 +9,7 @@ class ParameterSpec(BaseModel):
     default: Any
     optional: bool
     description: str = ""
+    type_meta: dict = {}
 
 
 class ArtifactSpec(BaseModel):
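Since type_meta defaults to an empty dict, previously generated specs without the field still validate, while new specs carry the constraint bounds emitted by ConInt.dict() / ConFloat.dict(). A brief sketch, assuming ParameterSpec's type field (defined just above the hunk) is a plain string and the only other required field:

from fate.components.spec.component import ParameterSpec

# old-style entry without type_meta still parses thanks to the default
legacy = ParameterSpec(type="float", default=0.1, optional=True, description="learning rate")
assert legacy.type_meta == {}

# new-style entry carries the bounds for the constrained type
spec = ParameterSpec(type="ConFloat", default=0.1, optional=True, description="learning rate", type_meta={"gt": 0.0})
print(spec.dict())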
